diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4f45a1d0143..499cde368df 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -806,7 +806,7 @@ jobs:
             -DBUILD_TESTING=ON \
             -DCMAKE_C_COMPILER_LAUNCHER=ccache \
             -DCMAKE_CXX_COMPILER_LAUNCHER=ccache
-          cmake --build . -j$(nproc) --target oneflow_deps of_cfgobj of_protoobj of_functional_obj of_functional_tensor_obj
+          cmake --build . -j$(nproc) --target oneflow_deps of_cfgobj of_protoobj of_functional_obj of_functional_tensor_obj of_op_schema
      - name: Fetch upstream
        if: ${{ !fromJSON(steps.save-cache.outputs.cache-hit) && github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name }}
        run: |
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4cdeade7db9..232d8827014 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -35,6 +35,9 @@ option(WITH_OPENVINO "Option to build with OpenVINO" OFF)
 option(WITH_MLIR "" OFF)
 option(WITH_MLIR_CUDA_CODEGEN "" OFF)
 set(LLVM_PROVIDER "in-tree" CACHE STRING "in-tree, install")
+if (NOT WITH_MLIR)
+  set(LLVM_PROVIDER "install" CACHE STRING "in-tree would build LLVM's ALL target, which is not wanted when not building MLIR" FORCE)
+endif(NOT WITH_MLIR)
 option(WITH_COCOAPI "Option to build with COCO API" ON)
 option(WITH_ZLIB "" ON)
 option(BUILD_GIT_VERSION "" ON)
@@ -223,6 +226,12 @@ if(BUILD_CPP_API)
   endif(BUILD_SHARED_LIBS)
 endif(BUILD_CPP_API)
 
+set(INJA_URL https://github.com/pantor/inja/archive/refs/tags/v3.3.0.zip CACHE STRING "")
+use_mirror(VARIABLE INJA_URL URL ${INJA_URL})
+set(INJA_MD5 611e6b7206d0fb89728a3879f78b4775 CACHE STRING "")
+set(JSON_URL https://github.com/nlohmann/json/releases/download/v3.7.3/include.zip CACHE STRING "")
+use_mirror(VARIABLE JSON_URL URL ${JSON_URL})
+set(JSON_MD5 fb96f95cdf609143e998db401ca4f324 CACHE STRING "")
 include(third_party)
 
 if (BUILD_CUDA)
@@ -275,3 +284,4 @@ add_custom_target(oneflow_deps ALL DEPENDS prepare_oneflow_third_party)
 if (ONEFLOW)
   include(oneflow)
 endif()
+add_subdirectory(ci)
diff --git a/ci/CMakeLists.txt b/ci/CMakeLists.txt
new file mode 100644
index 00000000000..552439ebc59
--- /dev/null
+++ b/ci/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(test)
diff --git a/ci/test/CMakeLists.txt b/ci/test/CMakeLists.txt
new file mode 100644
index 00000000000..1f7871e80a7
--- /dev/null
+++ b/ci/test/CMakeLists.txt
@@ -0,0 +1,25 @@
+set(PYTHON_EXECUTABLE python3 CACHE STRING "python3 exe to run test, usually is the python3 installation oneflow is linked to")
+set(ONEFLOW_SRC_DIR ${CMAKE_SOURCE_DIR} CACHE STRING "source dir of oneflow")
+set(IS_DEV ON CACHE BOOL "")
+set(CTEST_RESOURCE_SPEC_FILE "${CMAKE_CURRENT_SOURCE_DIR}/resource-spec/2x-rtx-2080.json" CACHE STRING "")
+
+# CTEST_OUTPUT_ON_FAILURE=1 CTEST_PARALLEL_LEVEL=20 ninja test
+
+file(GLOB_RECURSE PYTHON_TEST_FILES LIST_DIRECTORIES false RELATIVE ${ONEFLOW_SRC_DIR} "${ONEFLOW_SRC_DIR}/python/oneflow/test_*.py")
+foreach(PYTHON_TEST_FILE ${PYTHON_TEST_FILES})
+  set(TEST_NAME ${PYTHON_TEST_FILE})
+  add_test(NAME ${TEST_NAME}
+    COMMAND ${PYTHON_EXECUTABLE} ${ONEFLOW_SRC_DIR}/${PYTHON_TEST_FILE} --failfast --verbose
+  )
+  set_tests_properties(${TEST_NAME}
+    PROPERTIES
+    ENVIRONMENT "$<$<BOOL:${ONEFLOW_TEST_CPU_ONLY}>:ONEFLOW_TEST_CPU_ONLY=1>;$<$<BOOL:${IS_DEV}>:PYTHONPATH=${ONEFLOW_SRC_DIR}/python:$ENV{PYTHONPATH}>"
+    RESOURCE_GROUPS
+    "vram:2000"
+  )
+endforeach()
+set_tests_properties(python/oneflow/test/modules/test_rnn.py
+  PROPERTIES
+  RESOURCE_GROUPS
+  "vram:4000"
+)
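Note on the `RESOURCE_GROUPS` wiring above: the `vram:2000` requirement only constrains scheduling once CTest is pointed at one of the resource-spec JSON files added below, and the allocation is handed to each test through environment variables. A minimal sketch of the consuming side, not part of this patch (the variable names are CTest's documented resource-allocation interface):

```cpp
#include <cctype>
#include <cstdlib>
#include <iostream>
#include <string>

// Minimal sketch: how a test process can discover the vram slots CTest
// allocated to it. When a resource spec file is active, CTest exports
// CTEST_RESOURCE_GROUP_COUNT plus, per group <n>, CTEST_RESOURCE_GROUP_<n>
// (the resource type, e.g. "vram") and CTEST_RESOURCE_GROUP_<n>_<TYPE>
// (e.g. "id:0,slots:2000").
int main() {
  const char* count_env = std::getenv("CTEST_RESOURCE_GROUP_COUNT");
  if (count_env == nullptr) {
    std::cout << "no CTest resource allocation active\n";
    return 0;
  }
  const int count = std::atoi(count_env);
  for (int i = 0; i < count; ++i) {
    const std::string prefix = "CTEST_RESOURCE_GROUP_" + std::to_string(i);
    const char* type = std::getenv(prefix.c_str());  // e.g. "vram"
    if (type == nullptr) { continue; }
    std::string upper(type);
    for (char& c : upper) { c = static_cast<char>(std::toupper(static_cast<unsigned char>(c))); }
    const char* alloc = std::getenv((prefix + "_" + upper).c_str());
    // alloc looks like "id:0,slots:2000"; the id matches an entry in the
    // resource-spec JSON, so it can be mapped to CUDA_VISIBLE_DEVICES.
    std::cout << prefix << ": " << type << " -> " << (alloc ? alloc : "") << "\n";
  }
  return 0;
}
```

Running `ctest --resource-spec-file ci/test/resource-spec/2x-rtx-2080.json -j20` would then keep the concurrent tests within each GPU's slot budget declared below.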
diff --git a/ci/test/resource-spec/1x-gtx-1080.json b/ci/test/resource-spec/1x-gtx-1080.json
new file mode 100644
index 00000000000..81f888431bf
--- /dev/null
+++ b/ci/test/resource-spec/1x-gtx-1080.json
@@ -0,0 +1,16 @@
+{
+  "version": {
+    "major": 1,
+    "minor": 0
+  },
+  "local": [
+    {
+      "vram": [
+        {
+          "id": "0",
+          "slots": 8117
+        }
+      ]
+    }
+  ]
+}
diff --git a/ci/test/resource-spec/2x-rtx-2080.json b/ci/test/resource-spec/2x-rtx-2080.json
new file mode 100644
index 00000000000..a1e44586957
--- /dev/null
+++ b/ci/test/resource-spec/2x-rtx-2080.json
@@ -0,0 +1,20 @@
+{
+  "version": {
+    "major": 1,
+    "minor": 0
+  },
+  "local": [
+    {
+      "vram": [
+        {
+          "id": "0",
+          "slots": 7982
+        },
+        {
+          "id": "1",
+          "slots": 7982
+        }
+      ]
+    }
+  ]
+}
diff --git a/ci/test/resource-spec/4x-rtx-2080ti.json b/ci/test/resource-spec/4x-rtx-2080ti.json
new file mode 100644
index 00000000000..aa401817598
--- /dev/null
+++ b/ci/test/resource-spec/4x-rtx-2080ti.json
@@ -0,0 +1,28 @@
+{
+  "version": {
+    "major": 1,
+    "minor": 0
+  },
+  "local": [
+    {
+      "vram": [
+        {
+          "id": "0",
+          "slots": 11019
+        },
+        {
+          "id": "1",
+          "slots": 11019
+        },
+        {
+          "id": "2",
+          "slots": 11019
+        },
+        {
+          "id": "3",
+          "slots": 11019
+        }
+      ]
+    }
+  ]
+}
diff --git a/cmake/oneflow.cmake b/cmake/oneflow.cmake
index cca18432dfe..662dcc15e9f 100644
--- a/cmake/oneflow.cmake
+++ b/cmake/oneflow.cmake
@@ -150,6 +150,7 @@ add_custom_target(of_format
   COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_license_format.py -i ${ONEFLOW_PYTHON_DIR} --fix --exclude="oneflow/include" --exclude="oneflow/core"
   COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/oneflow --fix --quiet
   COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_py_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR} --fix
+  COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/ci/check/run_clang_format.py --source_dir ${CMAKE_CURRENT_SOURCE_DIR}/tools/oneflow-tblgen --fix --quiet
 )
 # clang tidy
 add_custom_target(of_tidy
@@ -244,6 +245,7 @@ oneflow_add_library(oneflow ${of_all_obj_cc})
 add_dependencies(oneflow of_protoobj)
 add_dependencies(oneflow of_cfgobj)
 add_dependencies(oneflow of_functional_obj)
+add_dependencies(oneflow of_op_schema)
 add_dependencies(oneflow of_git_version)
 
 if (USE_CLANG_FORMAT)
@@ -255,38 +257,33 @@ endif()
 
 target_compile_definitions(oneflow PRIVATE GOOGLE_LOGGING)
 
-set(ONEFLOW_TOOLS_DIR "${PROJECT_BINARY_DIR}/tools")
-oneflow_add_executable(oneflow-gen-ods EXCLUDE_FROM_ALL ${PROJECT_SOURCE_DIR}/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp)
-set_target_properties(oneflow-gen-ods PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${ONEFLOW_TOOLS_DIR}")
+set(ONEFLOW_TOOLS_DIR "${PROJECT_BINARY_DIR}/tools" CACHE STRING "dir to put binary for debugging and development")
 set(LLVM_MONO_REPO_URL "https://github.com/llvm/llvm-project/archive/649d95371680cbf7f740c990c0357372c2bd4058.zip" CACHE STRING "")
 use_mirror(VARIABLE LLVM_MONO_REPO_URL URL ${LLVM_MONO_REPO_URL})
 set(LLVM_MONO_REPO_MD5 "9bda804e5cc61899085fb0f0dce1089f" CACHE STRING "")
 set(ONEFLOW_BUILD_ROOT_DIR "${PROJECT_BINARY_DIR}")
+add_subdirectory(${PROJECT_SOURCE_DIR}/oneflow/ir)
 if (WITH_MLIR)
-  add_subdirectory(${PROJECT_SOURCE_DIR}/oneflow/ir)
   set(ONEFLOW_MLIR_LIBS -Wl,--no-as-needed MLIROneFlowExtension -Wl,--as-needed)
 endif()
 
+include(op_schema)
+
 if(APPLE)
-  set(of_libs -Wl,-force_load oneflow of_protoobj of_cfgobj of_functional_obj)
+  set(of_libs -Wl,-force_load oneflow of_protoobj of_cfgobj of_functional_obj of_op_schema)
   target_link_libraries(oneflow of_protoobj of_cfgobj of_functional_obj glog_imported
                                 gflags_imported ${oneflow_third_party_libs})
 elseif(UNIX)
-  set(of_libs -Wl,--whole-archive oneflow of_protoobj of_cfgobj of_functional_obj -Wl,--no-whole-archive -ldl -lrt)
+  set(of_libs -Wl,--whole-archive oneflow of_protoobj of_cfgobj of_functional_obj of_op_schema -Wl,--no-whole-archive -ldl -lrt)
   target_link_libraries(oneflow of_protoobj of_cfgobj of_functional_obj glog_imported gflags_imported
                                 ${oneflow_third_party_libs} -Wl,--no-whole-archive -ldl -lrt)
   if(BUILD_CUDA)
     target_link_libraries(oneflow CUDA::cudart_static)
   endif()
 elseif(WIN32)
-  set(of_libs oneflow of_protoobj of_cfgobj of_functional_obj)
+  set(of_libs oneflow of_protoobj of_cfgobj of_functional_obj of_op_schema)
   set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
 endif()
 
-target_link_libraries(oneflow-gen-ods ${of_libs} ${oneflow_third_party_libs} ${oneflow_exe_third_party_libs})
-if (BUILD_CUDA)
-  target_link_libraries(oneflow-gen-ods CUDA::cudart_static)
-endif()
-
 if(BUILD_PYTHON)
 
   # py ext lib
@@ -301,7 +298,7 @@ if(BUILD_PYTHON)
   pybind11_add_module(oneflow_internal ${PYBIND11_SRCS} ${of_pybind_obj_cc} ${PYBIND_REGISTRY_CC})
   set_compile_options_to_oneflow_target(oneflow_internal)
   set_property(TARGET oneflow_internal PROPERTY CXX_VISIBILITY_PRESET "default")
-  add_dependencies(oneflow_internal of_cfgobj of_functional_obj of_functional_tensor_obj)
+  add_dependencies(oneflow_internal of_cfgobj of_functional_obj of_functional_tensor_obj of_op_schema)
   set_target_properties(oneflow_internal PROPERTIES PREFIX "_")
   set_target_properties(oneflow_internal PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${ONEFLOW_PYTHON_DIR}/oneflow")
   target_link_libraries(oneflow_internal PRIVATE
@@ -340,7 +337,7 @@ if(BUILD_PYTHON)
 endif(BUILD_PYTHON)
 
 if (BUILD_CPP_API)
-  file(GLOB_RECURSE of_cpp_api_files
+  file(GLOB_RECURSE of_cpp_api_files
     ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.cpp
     ${PROJECT_SOURCE_DIR}/oneflow/api/cpp/*.h)
   if(BUILD_MONOLITHIC_LIBONEFLOW_CPP_SO)
@@ -362,6 +359,11 @@ function(oneflow_add_test target_name)
   endif()
   set_target_properties(${target_name} PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${PROJECT_BINARY_DIR}/bin")
   add_test(NAME ${arg_TEST_NAME} COMMAND ${target_name})
+  set_tests_properties(
+    ${arg_TEST_NAME}
+    PROPERTIES
+    ENVIRONMENT "HTTP_PROXY='';HTTPS_PROXY='';http_proxy='';https_proxy='';"
+  )
 endfunction()
 
 # build test
diff --git a/cmake/op_schema.cmake b/cmake/op_schema.cmake
new file mode 100644
index 00000000000..970910f94c9
--- /dev/null
+++ b/cmake/op_schema.cmake
@@ -0,0 +1,90 @@
+get_property(LLVM_INSTALL_DIR GLOBAL PROPERTY LLVM_INSTALL_DIR)
+set(LLVM_INSTALL_DIR ${THIRD_PARTY_DIR}/llvm)
+set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm)
+set(ONEFLOW_OP_GROUPS
+  "ASSIGN"
+  "BINARY"
+  "BROADCAST"
+  "CONV"
+  "CROSS_ENTROPY"
+  "CUDA"
+  "DATASET"
+  "DETECTION"
+  "EAGER"
+  "FUSED"
+  "IDEMPOTENT"
+  "IDENTITY"
+  "IMAGE"
+  "INDICES"
+  "INVOLUTION"
+  "LOSS"
+  "MATH"
+  "MATMUL"
+  "MISC"
+  "NCCL"
+  "NORMALIZATION"
+  "OPTIMIZER"
+  "PADDING"
+  "PARALLEL_CAST"
+  "POOL"
+  "QUANTIZATION"
+  "REDUCE"
+  "RESHAPE"
+  "SCALAR"
+  "SOFTMAX"
+  "SUMMARY"
+  "TENSOR_BUFFER"
+  "TEST"
+  "TRIGONOMETRIC"
+  "UNARY"
+  "UPSAMPLE"
+)
+foreach (OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS)
+  list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-DGET_ONEFLOW_${OP_GROUP_NAME}_OP_DEFINITIONS")
+endforeach()
+list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-DREMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS")
+
+set(GENERATED_OP_SCHEMA_DIR oneflow/core/framework)
+set(GENERATED_IR_INCLUDE_DIR oneflow/ir/include)
+set(SOURCE_IR_INCLUDE_DIR ${PROJECT_SOURCE_DIR}/oneflow/ir/include)
+set(ONEFLOW_ODS ${SOURCE_IR_INCLUDE_DIR}/OneFlow/OneFlowOps.td)
+
+list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${GENERATED_IR_INCLUDE_DIR}")
+list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${SOURCE_IR_INCLUDE_DIR}")
+list(APPEND ONEFLOW_SCHEMA_TABLEGEN_FLAGS "-I${LLVM_INSTALL_DIR}/include")
+
+set(GENERATED_OP_SCHEMA_H "${GENERATED_OP_SCHEMA_DIR}/op_generated.h")
+set(GENERATED_OP_SCHEMA_CPP "${GENERATED_OP_SCHEMA_DIR}/op_generated.cpp")
+
+
+set(ONEFLOW_TABLE_GEN_EXE ${LLVM_INSTALL_DIR}/bin/oneflow_tblgen)
+if(LLVM_PROVIDER STREQUAL "in-tree")
+  set(ONEFLOW_TABLE_GEN_TARGET oneflow_tblgen install-oneflow-tblgen install-mlir-headers)
+elseif(LLVM_PROVIDER STREQUAL "install")
+  set(ONEFLOW_TABLE_GEN_TARGET ${ONEFLOW_TABLE_GEN_EXE})
+endif()
+
+file(GLOB_RECURSE ODS_FILES LIST_DIRECTORIES false "${SOURCE_IR_INCLUDE_DIR}/*.td")
+if(NOT ODS_FILES)
+  message(FATAL_ERROR "ODS_FILES not found: ${ODS_FILES}")
+endif()
+add_custom_command(
+  OUTPUT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP}
+  COMMAND ${CMAKE_COMMAND}
+  ARGS -E make_directory ${GENERATED_OP_SCHEMA_DIR}
+  COMMAND ${ONEFLOW_TABLE_GEN_EXE}
+  ARGS --gen-op-schema-h ${ONEFLOW_ODS} ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS} -o ${GENERATED_OP_SCHEMA_H}
+  COMMAND ${ONEFLOW_TABLE_GEN_EXE}
+  ARGS --gen-op-schema-cpp ${ONEFLOW_ODS} ${ONEFLOW_SCHEMA_TABLEGEN_FLAGS}
+       --op-include ${GENERATED_OP_SCHEMA_H} -o ${GENERATED_OP_SCHEMA_CPP}
+  DEPENDS ${ONEFLOW_TABLE_GEN_TARGET}
+          ${ODS_FILES}
+  VERBATIM
+)
+set_source_files_properties(
+  ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP} PROPERTIES GENERATED TRUE
+)
+
+oneflow_add_library(of_op_schema OBJECT ${GENERATED_OP_SCHEMA_H} ${GENERATED_OP_SCHEMA_CPP})
+add_dependencies(of_op_schema of_cfgobj)
+add_dependencies(of_op_schema prepare_oneflow_third_party)
diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake
index 8a8edd31ba9..5875acce69c 100644
--- a/cmake/third_party.cmake
+++ b/cmake/third_party.cmake
@@ -219,6 +219,7 @@ list(APPEND ONEFLOW_THIRD_PARTY_INCLUDE_DIRS
   ${COCOAPI_INCLUDE_DIR}
   ${HALF_INCLUDE_DIR}
   ${JSON_INCLUDE_DIR}
+  ${INJA_INCLUDE_DIR}
   ${ABSL_INCLUDE_DIR}
   ${OPENSSL_INCLUDE_DIR}
   ${FLATBUFFERS_INCLUDE_DIR}
diff --git a/cmake/third_party/json.cmake b/cmake/third_party/json.cmake
index ee497029860..0774fe2696d 100644
--- a/cmake/third_party/json.cmake
+++ b/cmake/third_party/json.cmake
@@ -23,6 +23,7 @@ if(THIRD_PARTY)
   )
   add_custom_target(json_create_header_dir
     COMMAND ${CMAKE_COMMAND} -E make_directory ${JSON_INCLUDE_DIR}
+    COMMAND ${CMAKE_COMMAND} -E make_directory ${JSON_INCLUDE_DIR}/nlohmann
     DEPENDS json
   )
   add_custom_target(json_copy_headers_to_destination
@@ -31,6 +32,7 @@ if(THIRD_PARTY)
   foreach(header_file ${JSON_HEADERS})
     add_custom_command(TARGET json_copy_headers_to_destination PRE_BUILD
       COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${JSON_INCLUDE_DIR}
+      COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${JSON_INCLUDE_DIR}/nlohmann
     )
   endforeach()
 endif(THIRD_PARTY)
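The `op_schema.cmake` rules above run `oneflow_tblgen` twice over `OneFlowOps.td`, once with `--gen-op-schema-h` and once with `--gen-op-schema-cpp`, passing one `-DGET_ONEFLOW_<GROUP>_OP_DEFINITIONS` define per op group. As a rough illustration only, the generated `op_generated.h` can be expected to declare one schema class per op; the class name and exact shape below are assumptions modeled on `OpBase` from this PR, not actual generator output:

```cpp
// Hypothetical sketch of the kind of declaration --gen-op-schema-h could emit
// into op_generated.h for a single op. "ReluOp" here is an assumed name.
#include "oneflow/core/framework/op_base.h"

namespace oneflow {
namespace schema {

class ReluOp : public OpBase {
 public:
  // Resolves an attribute by name, backed by the op's own fields.
  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
  // Enumerates the attribute names this op defines.
  const HashSet<std::string>& AttrNames() const override;
};

}  // namespace schema
}  // namespace oneflow
```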
diff --git a/oneflow/core/eager/opkernel_instruction_type_test.cpp b/oneflow/core/eager/opkernel_instruction_type_test.cpp
index ba85b1358cc..20447f684ce 100644
--- a/oneflow/core/eager/opkernel_instruction_type_test.cpp
+++ b/oneflow/core/eager/opkernel_instruction_type_test.cpp
@@ -210,7 +210,6 @@ TEST(OpkernelInstructionType, call_opkernel) {
 TEST(OpkernelInstructionType, consecutive_opkernel_calls) {
   vm::TestResourceDescScope resource_scope(1, 1);
   InstructionMsgList list;
-  int64_t in_id = vm::TestUtil::NewStringSymbol(&list, "in_0");
   int64_t out_id = vm::TestUtil::NewStringSymbol(&list, "out_0");
   int64_t tmp_buffer_id = vm::TestUtil::NewStringSymbol(&list, "tmp_buffer_0");
   int64_t test_source_id = 0;
@@ -245,28 +244,30 @@ TEST(OpkernelInstructionType, consecutive_opkernel_calls) {
     op_conf->set_name("ccrelu_op_name");
     auto* user_conf = op_conf->mutable_user_conf();
     user_conf->set_op_type_name("ccrelu");
-    (*user_conf->mutable_input())["in"].add_s("ccrelu_op_name/in_0");
-    (*user_conf->mutable_output())["out"].add_s("ccrelu_op_name/out_0");
+    (*user_conf->mutable_input())["x"].add_s("ccrelu_op_name/x_0");
+    (*user_conf->mutable_output())["y"].add_s("ccrelu_op_name/y_0");
     ccrelu_id = InitOpKernelObject(&list, std::make_shared<JobConfigProto>(), op_conf, "gpu");
   }
   int64_t y = 0;
+  int64_t x_id = vm::TestUtil::NewStringSymbol(&list, "x_0");
+  int64_t y_id = vm::TestUtil::NewStringSymbol(&list, "y_0");
   {
     int64_t y_parallel_desc_id = 0;
     y = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id);
     int64_t tmp_buffer = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id);
     int64_t op_node_signature_id =
-        NewOpNodeSignature(&list, {"in_0"}, {x_parallel_desc_id}, {"out_0", "tmp_buffer_0"},
+        NewOpNodeSignature(&list, {"x_0"}, {x_parallel_desc_id}, {"y_0", "tmp_buffer_0"},
                            {y_parallel_desc_id, y_parallel_desc_id});
     list.EmplaceBack(vm::NewInstruction("gpu.CallOpKernel")
                          ->add_parallel_desc(y_parallel_desc_id)
                          ->add_mut_operand(ccrelu_id)
                          ->add_symbol_operand(op_node_signature_id)
                          ->add_separator()
-                         ->add_symbol_operand(in_id)
+                         ->add_symbol_operand(x_id)
                          ->add_const_operand(x)
                          ->add_separator()
                          ->add_separator()
-                         ->add_symbol_operand(out_id)
+                         ->add_symbol_operand(y_id)
                          ->add_symbol_operand(tmp_buffer_id)
                          ->add_mut_operand(y)
                          ->add_mut_operand(tmp_buffer)
@@ -365,22 +366,23 @@ TEST(OpkernelInstructionType, consecutive_stateless_call_opkernel) {
                        ->add_symbol_operand(out_id)
                        ->add_mut_operand(x)
                        ->add_separator());
-  int64_t in_id = vm::TestUtil::NewStringSymbol(&list, "in_0");
+  int64_t x_id = vm::TestUtil::NewStringSymbol(&list, "x_0");
+  int64_t y_id = vm::TestUtil::NewStringSymbol(&list, "y_0");
   int64_t ccrelu_id = 0;
   {
    auto op_conf = std::make_shared<OperatorConf>();
     op_conf->set_name("ccrelu_op_name");
     auto* user_conf = op_conf->mutable_user_conf();
     user_conf->set_op_type_name("ccrelu");
-    (*user_conf->mutable_input())["in"].add_s("ccrelu_op_name/in_0");
-    (*user_conf->mutable_output())["out"].add_s("ccrelu_op_name/out_0");
+    (*user_conf->mutable_input())["x"].add_s("ccrelu_op_name/x_0");
+    (*user_conf->mutable_output())["y"].add_s("ccrelu_op_name/y_0");
     ccrelu_id = NewOpConfSymbol(&list, op_conf);
   }
   int64_t y_parallel_desc_id = 0;
   int64_t y = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id);
   int64_t tmp_buffer = vm::TestUtil::NewObject(&list, "gpu", "0:0", &y_parallel_desc_id);
   op_node_signature_id =
-      NewOpNodeSignature(&list, {"in_0"}, {parallel_desc_id}, {"out_0", "tmp_buffer_0"},
+      NewOpNodeSignature(&list, {"x_0"}, {parallel_desc_id}, {"y_0", "tmp_buffer_0"},
                          {y_parallel_desc_id, y_parallel_desc_id});
   list.EmplaceBack(vm::NewInstruction("gpu.compute.UserStatelessCallOpKernel")
                        ->add_parallel_desc(y_parallel_desc_id)
@@ -389,11 +391,11 @@
                        ->add_symbol_operand(op_node_signature_id)
                        ->add_mut_operand(opkernel_id)
                        ->add_separator()
-                       ->add_symbol_operand(in_id)
+                       ->add_symbol_operand(x_id)
                        ->add_const_operand(x)
                        ->add_separator()
                        ->add_separator()
-                       ->add_symbol_operand(out_id)
+                       ->add_symbol_operand(y_id)
                        ->add_symbol_operand(tmp_buffer_id)
                       ->add_mut_operand(y)
                       ->add_mut_operand(tmp_buffer)
diff --git a/oneflow/core/framework/attr_value.cpp b/oneflow/core/framework/attr_value.cpp
index a30e107ffa3..f24d8077e15 100644
--- a/oneflow/core/framework/attr_value.cpp
+++ b/oneflow/core/framework/attr_value.cpp
@@ -19,17 +19,29 @@ namespace oneflow {
 
 template<typename T>
 const T& AttrValueCast(const user_op::AttrVal& attr_val) {
-  const auto* typed_attr = dynamic_cast<const user_op::TypedAttrVal<T>*>(&attr_val);
+  const auto* typed_attr = dynamic_cast<const user_op::TypedAttrValIf<T>*>(&attr_val);
   return CHECK_NOTNULL(typed_attr)->val();
 }
 
+template<typename T>
+std::shared_ptr<user_op::AttrVal> CastAttrValue(const T& attr_val) {
+  return std::make_shared<user_op::TypedAttrVal<T>>(attr_val);
+}
+
+template<typename T>
+std::shared_ptr<user_op::AttrVal> CastAttrValue(const T* attr_val) {
+  return std::make_shared<user_op::TypedAttrValRef<T>>(attr_val);
+}
+
 template<typename T>
 size_t HashTypedAttrVal(const T& val) {
   return std::hash<T>()(val);
 }
 
-#define INITIALIZE_ATTR_VALUE_CAST(field, T, attr_type) \
-  template const T& AttrValueCast(const user_op::AttrVal& attr_val); \
+#define INITIALIZE_ATTR_VALUE_CAST(field, T, attr_type)                          \
+  template const T& AttrValueCast(const user_op::AttrVal& attr_val);             \
+  template std::shared_ptr<user_op::AttrVal> CastAttrValue(const T& attr_val);   \
+  template std::shared_ptr<user_op::AttrVal> CastAttrValue(const T* attr_val);   \
   template size_t HashTypedAttrVal(const T& attr_val);
 
 OF_PP_FOR_EACH_TUPLE(INITIALIZE_ATTR_VALUE_CAST, ATTR_SEQ)
diff --git a/oneflow/core/framework/attr_value.h b/oneflow/core/framework/attr_value.h
index b02a379cecc..d7b67757cea 100644
--- a/oneflow/core/framework/attr_value.h
+++ b/oneflow/core/framework/attr_value.h
@@ -97,19 +97,25 @@ class AttrVal {
 };
 
 template<typename T>
-class TypedAttrVal final : public AttrVal {
+class TypedAttrValIf : public AttrVal {
  public:
-  TypedAttrVal(T v) : val_(v) {}
-  ~TypedAttrVal() = default;
+  virtual const T& val() const = 0;
+  size_t hash_value() const override { return std::hash<T>()(val()); }
 
-  size_t hash_value() const override { return std::hash<T>()(val_); }
   bool operator==(const AttrVal& other) const override {
-    auto* that = dynamic_cast<const TypedAttrVal<T>*>(&other);
+    auto* that = dynamic_cast<const TypedAttrValIf<T>*>(&other);
     if (that == nullptr) { return false; }
-    return this->val_ == that->val_;
+    return this->val() == that->val();
   }
+};
+
+template<typename T>
+class TypedAttrVal final : public TypedAttrValIf<T> {
+ public:
+  TypedAttrVal(T v) : val_(v) {}
+  ~TypedAttrVal() = default;
 
-  const T& val() const { return val_; }
+  const T& val() const override { return val_; }
 
  private:
   OF_DISALLOW_COPY_AND_MOVE(TypedAttrVal);
@@ -117,11 +123,31 @@ class TypedAttrVal final : public AttrVal {
   T val_;
 };
 
+template<typename T>
+class TypedAttrValRef final : public TypedAttrValIf<T> {
+ public:
+  TypedAttrValRef(const T* v) : val_(v) {}
+  ~TypedAttrValRef() = default;
+
+  const T& val() const override { return *val_; }
+
+ private:
+  OF_DISALLOW_COPY_AND_MOVE(TypedAttrValRef);
+
+  const T* val_;
+};
+
 }  // namespace user_op
 
 template<typename T>
 const T& AttrValueCast(const user_op::AttrVal& val);
 
+template<typename T>
+std::shared_ptr<user_op::AttrVal> CastAttrValue(const T& attr_val);
+
+template<typename T>
+std::shared_ptr<user_op::AttrVal> CastAttrValue(const T* attr_val);
+
 }  // namespace oneflow
 
 #endif  // ONEFLOW_CORE_FRAMEWORK_ATTR_VALUE_H_
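The split of `TypedAttrVal` into an interface plus two implementations gives attributes two ownership modes: `TypedAttrVal<T>` owns a copy, while `TypedAttrValRef<T>` aliases caller-owned storage (this is what lets the schema ops later in this PR expose plain struct fields as attribute views without copying). A small sketch of the difference using the two `CastAttrValue` overloads introduced above:

```cpp
#include "oneflow/core/framework/attr_value.h"

// Sketch: CastAttrValue(const T&) builds an owning TypedAttrVal copy, while
// CastAttrValue(const T*) builds a non-owning TypedAttrValRef that re-reads
// the caller's storage on every val() call.
void OwnershipDemo() {
  int32_t top_n = 5;
  auto owning = oneflow::CastAttrValue(top_n);     // TypedAttrVal<int32_t>, copies 5
  auto aliasing = oneflow::CastAttrValue(&top_n);  // TypedAttrValRef<int32_t>, points at top_n
  bool equal_before = (*owning == *aliasing);      // true: both read 5
  top_n = 10;
  bool equal_after = (*owning == *aliasing);       // false: the ref now reads 10
  (void)equal_before;
  (void)equal_after;
}
```

Equality works across the two implementations because `operator==` compares through the shared `TypedAttrValIf<T>` interface rather than the concrete storage.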
diff --git a/oneflow/core/framework/attr_value_accessor.cpp b/oneflow/core/framework/attr_value_accessor.cpp
index 79a63238840..2031b701f71 100644
--- a/oneflow/core/framework/attr_value_accessor.cpp
+++ b/oneflow/core/framework/attr_value_accessor.cpp
@@ -176,8 +176,8 @@ Maybe<AttrVal> MakeCppAttrValueFromProtoOrCfgAttrValue(const ProtoT& cfg_attr_va
 // clang-format off
 #define MAKE_ENTRY(field, cpp_type, attr_type)                                                   \
   }                                                                                              \
-  else if (dynamic_cast<const user_op::TypedAttrVal<cpp_type>*>(&cpp_attr_value) != nullptr) {   \
-    const auto* ptr = dynamic_cast<const user_op::TypedAttrVal<cpp_type>*>(&cpp_attr_value);     \
+  else if (dynamic_cast<const user_op::TypedAttrValIf<cpp_type>*>(&cpp_attr_value) != nullptr) { \
+    const auto* ptr = dynamic_cast<const user_op::TypedAttrValIf<cpp_type>*>(&cpp_attr_value);   \
     AttrValueAccessor<attr_type>::Attr(ptr->val(), attr_value);
 OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ);
 #undef MAKE_ENTRY
diff --git a/oneflow/core/framework/op_attrs.cpp b/oneflow/core/framework/op_attrs.cpp
new file mode 100644
index 00000000000..7ae08477890
--- /dev/null
+++ b/oneflow/core/framework/op_attrs.cpp
@@ -0,0 +1,58 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_attrs.h"
+
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/framework/op_interp_ctx.h"
+
+namespace oneflow {
+
+size_t OpAttrs::count(const std::string& attr_name) const {
+  return ctx_->AttrNames().count(attr_name);
+}
+
+Maybe<AttrVal> OpAttrs::at(const std::string& attr_name) const { return ctx_->GetAttr(attr_name); }
+Maybe<AttrVal> OpAttrs::operator[](const std::string& attr_name) const {
+  return ctx_->GetAttr(attr_name);
+}
+
+OpAttrs::const_iterator OpAttrs::begin() const {
+  const auto& attrs = ctx_->AttrNames();
+  return const_iterator(attrs.cbegin(), attrs.cend(), this);
+}
+OpAttrs::const_iterator OpAttrs::end() const {
+  const auto& attrs = ctx_->AttrNames();
+  return const_iterator(attrs.cend(), attrs.cend(), this);
+}
+
+bool OpAttrs::operator==(const OpAttrs& other) const {
+  // TODO(hjchen2): Compare each attribute
+  return ctx_ == other.ctx_;
+}
+
+}  // namespace oneflow
+
+namespace std {
+
+size_t hash<oneflow::OpAttrs>::operator()(const oneflow::OpAttrs& attrs) const {
+  size_t hash_val = 0;
+  for (const auto& it : attrs) {
+    oneflow::AddHash(&hash_val, it.first);
+    oneflow::HashCombine(&hash_val, it.second->hash_value());
+  }
+  return hash_val;
+}
+
+}  // namespace std
diff --git a/oneflow/core/framework/op_attrs.h b/oneflow/core/framework/op_attrs.h
new file mode 100644
index 00000000000..46f6df71f18
--- /dev/null
+++ b/oneflow/core/framework/op_attrs.h
@@ -0,0 +1,102 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_FRAMEWORK_OP_ATTRS_H_
+#define ONEFLOW_CORE_FRAMEWORK_OP_ATTRS_H_
+
+#include <memory>
+#include <string>
+
+#include "oneflow/core/common/hash_container.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/framework/attr_value.h"
+
+namespace oneflow {
+
+using user_op::AttrVal;
+
+class OpInterpCtx;
+
+class OpAttrs {
+ public:
+  explicit OpAttrs(const OpInterpCtx* ctx) : ctx_(ctx) {}
+
+  size_t count(const std::string& attr_name) const;
+
+  template<typename T>
+  Maybe<const T&> at(const std::string& attr_name) {
+    return AttrValueCast<T>(*JUST(this->at(attr_name)));
+  }
+  Maybe<AttrVal> at(const std::string& attr_name) const;
+  Maybe<AttrVal> operator[](const std::string& attr_name) const;
+
+  class const_iterator {
+   public:
+    using bucket_iter = HashSet<std::string>::const_iterator;
+    using reference = const std::pair<std::string, std::shared_ptr<const AttrVal>>&;
+    using pointer = const std::pair<std::string, std::shared_ptr<const AttrVal>>*;
+
+    const_iterator() = default;
+    const_iterator(bucket_iter pos, bucket_iter limit, const OpAttrs* self)
+        : pos_(pos), limit_(limit), self_(self) {
+      CHECK_JUST(UpdateKV());
+    }
+    reference operator*() const { return kv_; }
+    pointer operator->() const { return &kv_; }
+
+    const_iterator& operator++() {
+      pos_++;
+      CHECK_JUST(UpdateKV());
+      return *this;
+    }
+    bool operator==(const const_iterator& x) const { return pos_ == x.pos_ && self_ == x.self_; }
+    bool operator!=(const const_iterator& x) const { return !(*this == x); }
+
+   private:
+    Maybe<void> UpdateKV() {
+      if (pos_ != limit_) {
+        kv_.first = *pos_;
+        kv_.second = JUST(self_->at(*pos_));
+      }
+      return Maybe<void>::Ok();
+    }
+
+    bucket_iter pos_;
+    bucket_iter limit_;
+    const OpAttrs* self_;
+    std::pair<std::string, std::shared_ptr<const AttrVal>> kv_;
+  };
+
+  const_iterator begin() const;
+  const_iterator end() const;
+
+  bool operator==(const OpAttrs& other) const;
+
+ private:
+  const OpInterpCtx* ctx_;
+};
+
+}  // namespace oneflow
+
+namespace std {
+
+template<>
+struct hash<oneflow::OpAttrs> {
+  size_t operator()(const oneflow::OpAttrs& attrs) const;
+};
+
+}  // namespace std
+
+#endif  // ONEFLOW_CORE_FRAMEWORK_OP_ATTRS_H_
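`OpAttrs` is a lazy view: `count`, `at`, and iteration all delegate to the underlying `OpInterpCtx`, and `const_iterator` materializes each (name, value) pair on demand in `UpdateKV`. A usage sketch, assuming only the declarations above:

```cpp
#include <iostream>

#include "oneflow/core/framework/op_attrs.h"
#include "oneflow/core/framework/op_interp_ctx.h"

// Sketch: each dereference yields a (name, shared_ptr<const AttrVal>) pair;
// the per-attribute hash_value() is what std::hash<OpAttrs> combines.
void DumpAttrs(const oneflow::OpInterpCtx& ctx) {
  oneflow::OpAttrs attrs = ctx.GetAttrs();
  for (const auto& kv : attrs) {
    std::cout << "attr " << kv.first << " (hash " << kv.second->hash_value() << ")\n";
  }
}
```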
diff --git a/oneflow/core/framework/op_base.h b/oneflow/core/framework/op_base.h
new file mode 100644
index 00000000000..0f74faf5153
--- /dev/null
+++ b/oneflow/core/framework/op_base.h
@@ -0,0 +1,55 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_FRAMEWORK_OP_BASE_H_
+#define ONEFLOW_CORE_FRAMEWORK_OP_BASE_H_
+
+#include <string>
+
+#include "oneflow/core/common/hash_container.h"
+#include "oneflow/core/common/maybe.h"
+
+namespace oneflow {
+
+namespace user_op {
+class AttrVal;
+}  // namespace user_op
+using AttrVal = user_op::AttrVal;
+
+class OpBase {
+ public:
+  virtual ~OpBase() = default;
+
+  virtual Maybe<AttrVal> GetAttr(const std::string& attr_name) const = 0;
+
+  virtual const HashSet<std::string>& AttrNames() const {
+    static const HashSet<std::string> attr_names;
+    return attr_names;
+  }
+
+ protected:
+  OpBase() = default;
+};
+
+class FakeOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override {
+    return Error::RuntimeError() << "`FakeOp` has no attribute.";
+  }
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_FRAMEWORK_OP_BASE_H_
diff --git a/oneflow/core/framework/op_interp_ctx.cpp b/oneflow/core/framework/op_interp_ctx.cpp
new file mode 100644
index 00000000000..8e950fbf78f
--- /dev/null
+++ b/oneflow/core/framework/op_interp_ctx.cpp
@@ -0,0 +1,66 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/op_interp_ctx.h"
+#include "oneflow/core/framework/attr_value.h"
+
+namespace oneflow {
+
+Maybe<AttrVal> OpInterpCtx::GetAttr(const std::string& attr_name) const {
+  return op_->GetAttr(attr_name);
+}
+
+template<typename T>
+Maybe<const T&> OpInterpCtx::GetAttr(const std::string& attr_name) const {
+  const auto& attr_val = JUST(this->GetAttr(attr_name));
+  if (const auto* ptr = dynamic_cast<const user_op::TypedAttrValIf<T>*>(attr_val.get())) {
+    return ptr->val();
+  }
+  return Error::RuntimeError() << "Invalid type for attribute " << attr_name;
+}
+
+OpAttrs OpInterpCtx::GetAttrs() const { return OpAttrs(this); }
+
+template<typename T>
+Maybe<void> OpInterpCtx::SetAttr(const std::string& attr_name, const T& attr_val) {
+  *const_cast<T*>(&JUST(this->GetAttr<T>(attr_name))) = attr_val;
+  return Maybe<void>::Ok();
+}
+
+#define INSTANCE_ATTR_GETTER_AND_SETTER(field, T, attr_type)                                    \
+  template Maybe<const T&> OpInterpCtx::GetAttr<T>(const std::string& attr_name) const;         \
+  template Maybe<void> OpInterpCtx::SetAttr<T>(const std::string& attr_name, const T& attr_val);
+
+OF_PP_FOR_EACH_TUPLE(INSTANCE_ATTR_GETTER_AND_SETTER, ATTR_SEQ)
+#undef INSTANCE_ATTR_GETTER_AND_SETTER
+
+Maybe<void> OpInterpCtx::SetAttr(const std::string& attr_name, const AttrVal& attr_val) {
+#define MAKE_ENTRY(field, cpp_type, attr_type)                                                \
+  if (const auto* ptr = dynamic_cast<const user_op::TypedAttrValIf<cpp_type>*>(&attr_val)) {  \
+    return this->SetAttr<cpp_type>(attr_name, ptr->val());                                    \
+  }
+
+  OF_PP_FOR_EACH_TUPLE(MAKE_ENTRY, ATTR_SEQ);
+#undef MAKE_ENTRY
+  return Error::RuntimeError() << "Invalid type for attribute " << attr_name;
+}
+
+bool OpInterpCtx::HasAttr(const std::string& attr_name) const {
+  return AttrNames().count(attr_name) > 0;
+}
+
+const HashSet<std::string>& OpInterpCtx::AttrNames() const { return op_->AttrNames(); }
+
+}  // namespace oneflow
diff --git a/oneflow/core/framework/op_interp_ctx.h b/oneflow/core/framework/op_interp_ctx.h
new file mode 100644
index 00000000000..771766f6c8e
--- /dev/null
+++ b/oneflow/core/framework/op_interp_ctx.h
@@ -0,0 +1,73 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_FRAMEWORK_OP_INTERP_CTX_H_
+#define ONEFLOW_CORE_FRAMEWORK_OP_INTERP_CTX_H_
+
+#include <string>
+
+#include "oneflow/core/common/hash_container.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/common/symbol.h"
+#include "oneflow/core/framework/attr_value.h"
+#include "oneflow/core/framework/nd_sbp.h"
+#include "oneflow/core/framework/op_attrs.h"
+#include "oneflow/core/framework/op_base.h"
+#include "oneflow/core/job/parallel_desc.h"
+#include "oneflow/core/job/sbp_parallel.cfg.h"
+
+namespace oneflow {
+
+using user_op::AttrVal;
+template<typename T>
+using TypedAttrValRef = user_op::TypedAttrValRef<T>;
+
+namespace user_op {
+class OpKernelState;
+}  // namespace user_op
+
+class OpInterpCtx {
+ public:
+  explicit OpInterpCtx(const std::shared_ptr<OpBase>& op) : op_(op) {}
+  virtual ~OpInterpCtx() = default;
+
+  template<typename T>
+  Maybe<const T&> GetAttr(const std::string& attr_name) const;
+
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const;
+
+  OpAttrs GetAttrs() const;
+
+  template<typename T>
+  Maybe<void> SetAttr(const std::string& attr_name, const T& attr_val);
+
+  Maybe<void> SetAttr(const std::string& attr_name, const AttrVal& attr_val);
+
+  bool HasAttr(const std::string& attr_name) const;
+
+  const HashSet<std::string>& AttrNames() const;
+
+ public:
+  std::shared_ptr<OpBase> op_;
+
+  Optional<Symbol<Device>> device;               // for local op
+  Optional<Symbol<ParallelDesc>> parallel_desc;  // for consistent op
+  Optional<Symbol<cfg::NdSbp>> sbp;              // for consistent op
+  Optional<user_op::OpKernelState> state;
+};
+
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_FRAMEWORK_OP_INTERP_CTX_H_
diff --git a/oneflow/core/framework/system_ops.cpp b/oneflow/core/framework/system_ops.cpp
new file mode 100644
index 00000000000..44b449fe5ec
--- /dev/null
+++ b/oneflow/core/framework/system_ops.cpp
@@ -0,0 +1,115 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#include "oneflow/core/framework/system_ops.h"
+#include "oneflow/core/framework/attr_value.h"
+
+namespace oneflow {
+namespace schema {
+
+Maybe<AttrVal> CastToConsistentOp::GetAttr(const std::string& attr_name) const {
+  if (attr_name == "shape") {
+    return CastAttrValue(&shape);
+  } else if (attr_name == "dtype") {
+    return CastAttrValue(&dtype);
+  } else {
+    return Error::RuntimeError() << "CastToConsistent op has no attribute named " << attr_name;
+  }
+}
+
+const HashSet<std::string>& CastToConsistentOp::AttrNames() const {
+  static HashSet<std::string> attr_names{"shape", "dtype"};
+  return attr_names;
+}
+
+Maybe<AttrVal> SelectTopNOp::GetAttr(const std::string& attr_name) const {
+  if (attr_name == "top_n") {
+    return CastAttrValue(&top_n);
+  } else {
+    return Error::RuntimeError() << "SelectTopN op has no attribute named " << attr_name;
+  }
+}
+
+const HashSet<std::string>& SelectTopNOp::AttrNames() const {
+  static HashSet<std::string> attr_names{"top_n"};
+  return attr_names;
+}
+
+Maybe<AttrVal> FeedInputOp::GetAttr(const std::string& attr_name) const {
+  return Error::RuntimeError() << "FeedInput op has no attribute named " << attr_name;
+}
+
+Maybe<AttrVal> FetchOutputOp::GetAttr(const std::string& attr_name) const {
+  return Error::RuntimeError() << "FetchOutput op has no attribute named " << attr_name;
+}
+
+Maybe<AttrVal> FeedVariableOp::GetAttr(const std::string& attr_name) const {
+  if (attr_name == "_l2") {
+    return CastAttrValue(&_l2);
+  } else {
+    return Error::RuntimeError() << "FeedVariable op has no attribute named " << attr_name;
+  }
+}
+
+const HashSet<std::string>& FeedVariableOp::AttrNames() const {
+  static HashSet<std::string> attr_names{"_l2"};
+  return attr_names;
+}
+
+Maybe<AttrVal> ImageDecoderRandomCropResizeOp::GetAttr(const std::string& attr_name) const {
+  if (attr_name == "target_width") {
+    return CastAttrValue(&target_width);
+  } else if (attr_name == "target_height") {
+    return CastAttrValue(&target_height);
+  } else if (attr_name == "num_workers") {
+    return CastAttrValue(&num_workers);
+  } else if (attr_name == "max_num_pixels") {
+    return CastAttrValue(&max_num_pixels);
+  } else if (attr_name == "warmup_size") {
+    return CastAttrValue(&warmup_size);
+  } else if (attr_name == "seed") {
+    return CastAttrValue(&seed);
+  } else if (attr_name == "num_attempts") {
+    return CastAttrValue(&num_attempts);
+  } else if (attr_name == "random_area_min") {
+    return CastAttrValue(&random_area_min);
+  } else if (attr_name == "random_area_max") {
+    return CastAttrValue(&random_area_max);
+  } else if (attr_name == "random_aspect_ratio_min") {
+    return CastAttrValue(&random_aspect_ratio_min);
+  } else if (attr_name == "random_aspect_ratio_max") {
+    return CastAttrValue(&random_aspect_ratio_max);
+  } else {
+    return Error::RuntimeError() << "ImageDecoderRandomCropResize op has no attribute named "
+                                 << attr_name;
+  }
+}
+
+const HashSet<std::string>& ImageDecoderRandomCropResizeOp::AttrNames() const {
+  static HashSet<std::string> attr_names{"target_width",
+                                         "target_height",
+                                         "num_workers",
+                                         "max_num_pixels",
+                                         "warmup_size",
+                                         "seed",
+                                         "num_attempts",
+                                         "random_area_min",
+                                         "random_area_max",
+                                         "random_aspect_ratio_min",
+                                         "random_aspect_ratio_max"};
+  return attr_names;
+}
+
+}  // namespace schema
+}  // namespace oneflow
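Putting the pieces together, a schema op is a plain struct whose fields back the attribute interface by reference (note the final else-branch originally repeated the FeedVariable message; corrected above). A sketch of typed access through `OpInterpCtx`, assuming only the declarations in this PR:

```cpp
#include <memory>

#include "oneflow/core/framework/op_interp_ctx.h"
#include "oneflow/core/framework/system_ops.h"

// Sketch: SelectTopNOp stores top_n as a plain field; OpInterpCtx adapts it
// to the generic attribute interface, reading through a TypedAttrValRef.
oneflow::Maybe<void> Demo() {
  auto op = std::make_shared<oneflow::schema::SelectTopNOp>();
  op->top_n = 3;
  oneflow::OpInterpCtx ctx(op);
  CHECK_OR_RETURN(ctx.HasAttr("top_n"));
  int32_t n = JUST(ctx.GetAttr<int32_t>("top_n"));  // resolved via GetAttr("top_n")
  CHECK_EQ_OR_RETURN(n, 3);
  return oneflow::Maybe<void>::Ok();
}
```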
diff --git a/oneflow/core/framework/system_ops.h b/oneflow/core/framework/system_ops.h
new file mode 100644
index 00000000000..69b1fad6858
--- /dev/null
+++ b/oneflow/core/framework/system_ops.h
@@ -0,0 +1,89 @@
+/*
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+#ifndef ONEFLOW_CORE_FRAMEWORK_SYSTEM_OPS_H_
+#define ONEFLOW_CORE_FRAMEWORK_SYSTEM_OPS_H_
+
+#include "oneflow/core/framework/op_base.h"
+
+#include "oneflow/core/common/data_type.pb.h"
+#include "oneflow/core/common/hash_container.h"
+#include "oneflow/core/common/maybe.h"
+#include "oneflow/core/common/shape.h"
+
+namespace oneflow {
+namespace schema {
+
+class CastToConsistentOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
+  const HashSet<std::string>& AttrNames() const override;
+
+ public:
+  Shape shape;
+  DataType dtype;
+};
+
+class SelectTopNOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
+  const HashSet<std::string>& AttrNames() const override;
+
+ public:
+  int32_t top_n;
+};
+
+class FeedInputOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
+};
+
+class FetchOutputOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
+};
+
+class FeedVariableOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
+  const HashSet<std::string>& AttrNames() const override;
+
+ public:
+  double _l2;
+};
+
+class ImageDecoderRandomCropResizeOp : public OpBase {
+ public:
+  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
+  const HashSet<std::string>& AttrNames() const override;
+
+ public:
+  int64_t target_width;
+  int64_t target_height;
+  int64_t num_workers;
+  int64_t max_num_pixels;
+  int64_t warmup_size;
+  int64_t seed;
+  int64_t num_attempts;
+  float random_area_min;
+  float random_area_max;
+  float random_aspect_ratio_min;
+  float random_aspect_ratio_max;
+};
+
+}  // namespace schema
+}  // namespace oneflow
+
+#endif  // ONEFLOW_CORE_FRAMEWORK_SYSTEM_OPS_H_
diff --git a/oneflow/core/framework/user_op_conf.cpp b/oneflow/core/framework/user_op_conf.cpp
index d34a6700977..0b4ac26e617 100644
--- a/oneflow/core/framework/user_op_conf.cpp
+++ b/oneflow/core/framework/user_op_conf.cpp
@@ -278,17 +278,11 @@ Maybe<void> CheckArgDefIsValidInUserOpConf(
     if (arg_name2lbns.find(arg.name()) != arg_name2lbns.end()) {
       arg_blob_num = arg_name2lbns.at(arg.name()).s_size();
     }
-    if (arg_blob_num != arg.num()) {
-      if (arg_blob_num == 0) {
-        CHECK_OR_RETURN(arg.is_optional())
-            << " op_name: " << op_name << " op_type_name: " << op_type_name
-            << " arg name: " << arg.name() << " in OpDef must have blob in op_conf";
-      } else {
-        CHECK_OR_RETURN(arg_blob_num > arg.num() && arg.num_as_min())
-            << " op_name: " << op_name << " op_type_name: " << op_type_name
-            << " arg name: " << arg.name() << " has blob num: " << arg_blob_num
-            << " in op_conf does not meet its constraints in OpDef";
-      }
+    if (arg_blob_num == 0) {
+      CHECK_OR_RETURN(arg.is_optional())
+          << " op_name: " << op_name << " op_type_name: " << op_type_name
+          << " arg name: " << arg.name() << " in OpDef must have blob in op_conf: \n"
+          << op_conf.DebugString();
     }
     op_def_arg_names.insert(arg.name());
   }
@@ -358,24 +352,6 @@ Maybe<void> AddAttrDefaultValueAndCheckValid(const UserOpDef& op_def, OperatorConf* op_conf) {
   return AddAttrDefaultValueAndCheckValid(op_def, user_conf, error_msg_prefix);
 }
 
-Maybe<void> AddUserOpConfOutputDefaultArg(const UserOpDef& op_def, OperatorConf* op_conf) {
-  UserOpConf* user_conf = op_conf->mutable_user_conf();
-  // add default output arg and lbn
-  for (const auto& output_arg : op_def.output()) {
-    if (user_conf->output().find(output_arg.name()) == user_conf->output().end()
-        && (!output_arg.is_optional()) && (!output_arg.num_as_min())) {
-      for (int32_t i = 0; i < output_arg.num(); ++i) {
-        std::string lbn = GenLogicalBlobName(op_conf->name(), GenRepeatedBn(output_arg.name(), i));
-        (*(user_conf->mutable_output()))[output_arg.name()].add_s(lbn);
-        CHECK_EQ_OR_RETURN(i + 1, user_conf->output().at(output_arg.name()).s_size());
-      }
-      user_conf->add_output_order(output_arg.name());
-      CHECK_EQ_OR_RETURN(user_conf->output().size(), user_conf->output_order().size());
-    }
-  }
-  return Maybe<void>::Ok();
-}
-
 Maybe<AttrType> GetAttrTypeImpl(const std::string& op_type_name, const std::string& attr_name) {
   const user_op::OpRegistryResult* val =
       user_op::UserOpRegistryMgr::Get().GetOpRegistryResult(op_type_name);
@@ -397,7 +373,6 @@ Maybe<OperatorConf> CheckAndCompleteUserOpConfImpl(const OperatorConf& op_conf)
   const UserOpDef& op_def = val->op_def;
 
   JUST(AddAttrDefaultValueAndCheckValid(op_def, &ret));
-  JUST(AddUserOpConfOutputDefaultArg(op_def, &ret));
   // check input and output valid
   JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->input(), op_def.input()));
   JUST(CheckArgDefIsValidInUserOpConf(op_conf, user_conf->output(), op_def.output()));
diff --git a/oneflow/core/framework/user_op_def.cpp b/oneflow/core/framework/user_op_def.cpp
index 8aee596d2c5..def5da20d92 100644
--- a/oneflow/core/framework/user_op_def.cpp
+++ b/oneflow/core/framework/user_op_def.cpp
@@ -51,12 +51,6 @@ bool UserOpDefWrapper::IsArgOptional(const std::string& name) const {
   return arg_def->is_optional();
 }
 
-std::pair<int32_t, bool> UserOpDefWrapper::ArgNumAndIsMin(const std::string& name) const {
-  const UserOpDef::ArgDef* arg_def = GetArgPointer(name);
-  CHECK_NOTNULL(arg_def);
-  return std::make_pair(arg_def->num(), arg_def->num_as_min());
-}
-
 const UserOpDef::ArgDef* UserOpDefWrapper::GetArgPointer(const std::string& name) const {
   auto it = inputs_.find(name);
   if (it != inputs_.end()) { return it->second; }
diff --git a/oneflow/core/framework/user_op_def.h b/oneflow/core/framework/user_op_def.h
index 8d39a4333cf..f3c2aab548e 100644
--- a/oneflow/core/framework/user_op_def.h
+++ b/oneflow/core/framework/user_op_def.h
@@ -37,7 +37,6 @@ class UserOpDefWrapper final {
   bool IsAttrName(const std::string&) const;
 
   bool IsArgOptional(const std::string&) const;
-  std::pair<int32_t, bool> ArgNumAndIsMin(const std::string&) const;
 
   AttrType GetAttrType(const std::string&) const;
   bool AttrHasDefaultVal(const std::string&) const;
diff --git a/oneflow/core/framework/user_op_def.proto b/oneflow/core/framework/user_op_def.proto
index d22c8c808b6..e4ea72157e3 100644
--- a/oneflow/core/framework/user_op_def.proto
+++ b/oneflow/core/framework/user_op_def.proto
@@ -9,8 +9,6 @@ message UserOpDef {
   message ArgDef {
     required string name = 1;
     optional bool is_optional = 2 [default = false];
-    required int32 num = 3;
-    required bool num_as_min = 4;
   }
   repeated ArgDef input = 2;
   repeated ArgDef output = 3;
diff --git a/oneflow/core/framework/user_op_registry.cpp b/oneflow/core/framework/user_op_registry.cpp
index e40bc0bea3f..e763f9f5def 100644
--- a/oneflow/core/framework/user_op_registry.cpp
+++ b/oneflow/core/framework/user_op_registry.cpp
@@ -42,15 +42,13 @@ OpRegistry& OpRegistry::Name(const std::string& op_type_name) {
   return *this;
 }
 
-OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional,
-                                int32_t num, bool num_as_min) {
-  CHECK(InsertIfNotExists(name, &unique_names_));
+OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional) {
+  CHECK(InsertIfNotExists(name, &unique_names_))
+      << "op arg registered, name: " << name << ", op: " << result_.op_type_name;
   UserOpDef::ArgDef arg_def;
   {
     arg_def.set_name(name);
     arg_def.set_is_optional(is_optional);
-    arg_def.set_num(num);
-    arg_def.set_num_as_min(num_as_min);
   }
   if (is_input) {
     *(result_.op_def.mutable_input()->Add()) = arg_def;
@@ -60,15 +58,9 @@ OpRegistry& OpRegistry::ArgImpl(bool is_input, const std::string& name, bool is_optional) {
   return *this;
 }
 
-#define OP_REG_ARG_MEMBER_FUNC(name_prefix, is_input, is_optional)                             \
-  OpRegistry& OpRegistry::name_prefix(const std::string& name) {                               \
-    return ArgImpl(is_input, name, is_optional, 1, false);                                     \
-  }                                                                                            \
-  OpRegistry& OpRegistry::name_prefix(const std::string& name, int32_t num) {                  \
-    return ArgImpl(is_input, name, is_optional, num, false);                                   \
-  }                                                                                            \
-  OpRegistry& OpRegistry::name_prefix##WithMinimum(const std::string& name, int32_t min_num) { \
-    return ArgImpl(is_input, name, is_optional, min_num, true);                                \
+#define OP_REG_ARG_MEMBER_FUNC(name_prefix, is_input, is_optional) \
+  OpRegistry& OpRegistry::name_prefix(const std::string& name) {   \
+    return ArgImpl(is_input, name, is_optional);                   \
   }
 
 OP_REG_ARG_MEMBER_FUNC(Input, true, false)
diff --git a/oneflow/core/framework/user_op_registry.h b/oneflow/core/framework/user_op_registry.h
index fb3d69ecb56..036aa792acd 100644
--- a/oneflow/core/framework/user_op_registry.h
+++ b/oneflow/core/framework/user_op_registry.h
@@ -48,12 +48,13 @@ using SbpSignatureInferFn = std::function<Maybe<void>(InferSbpSignatureFnContext
 using InputArgModifier = InputBlobModifier;
 using GetInputArgModifier = std::function<InputArgModifier*(const std::string&, int32_t)>;
-using InputArgModifyFn = std::function<Maybe<void>(GetInputArgModifier, const UserOpConfWrapper&)>;
+using InputArgModifyFn =
+    std::function<Maybe<void>(const GetInputArgModifier&, const UserOpConfWrapper&)>;
 using OutputArgModifier = OutputBlobModifier;
 using GetOutputArgModifier = std::function<OutputArgModifier*(const std::string&, int32_t)>;
 using OutputArgModifyFn =
-    std::function<Maybe<void>(GetOutputArgModifier, const UserOpConfWrapper&)>;
+    std::function<Maybe<void>(const GetOutputArgModifier&, const UserOpConfWrapper&)>;
 using OutputBlobTimeShapeInferFn = std::function<Maybe<void>(InferOutputBlobTimeShapeFnContext*)>;
 using NdSbpInferFn = std::function<Maybe<void>(InferNdSbpFnContext*)>;
@@ -129,8 +130,7 @@ class OpRegistry final {
   OpRegistryResult GetResult() { return result_; }
 
  private:
-  OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional, int32_t num,
-                      bool num_as_min);
+  OpRegistry& ArgImpl(bool is_input, const std::string& name, bool is_optional);
   OpRegistry& DefaultedAttr(const std::string& name, AttrType type,
                             const std::function<void(UserOpDefWrapper*)>& SetDefault);
diff --git a/oneflow/core/functional/impl/activation_functor.cpp b/oneflow/core/functional/impl/activation_functor.cpp
index 122f0f98d81..26f2ab1c874 100644
--- a/oneflow/core/functional/impl/activation_functor.cpp
+++ b/oneflow/core/functional/impl/activation_functor.cpp
@@ -36,9 +36,7 @@ namespace impl {
 
 class ReluFunctor {
  public:
-  ReluFunctor() {
-    op_ = CHECK_JUST(one::OpBuilder("relu").Input("in", 1).Output("out", 1).Build());
-  }
+  ReluFunctor() { op_ = CHECK_JUST(one::OpBuilder("relu").Input("x", 1).Output("y", 1).Build()); }
   Maybe<Tensor> operator()(const std::shared_ptr<Tensor>& x, bool inplace) const {
     if (inplace) {
       JUST(CheckInplaceValid(x));
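With `num`/`num_as_min` gone from `ArgDef`, the registration surface shrinks to name-plus-optionality: the `Input(name, num)` and `InputWithMinimum(name, min_num)` overloads are removed, which is why multiplicity is now spelled only on the `OpBuilder` side, as in `ReluFunctor` above. A before/after sketch with a hypothetical op name (a real registration would also attach shape/dtype infer functions):

```cpp
#include "oneflow/core/framework/user_op_registry_manager.h"

// Before this PR, an argument carried a count constraint at registration time:
//   REGISTER_USER_OP("demo_relu").Input("in", 1).Output("out", 1);
// After it, only the argument name (and optionality) is registered.
// "demo_relu" is a made-up op_type_name for illustration.
REGISTER_USER_OP("demo_relu")
    .Input("x")
    .Output("y");
```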
diff --git a/oneflow/core/functional/impl/nn_functor.cpp b/oneflow/core/functional/impl/nn_functor.cpp
index c4ef99005e6..3af295abc05 100644
--- a/oneflow/core/functional/impl/nn_functor.cpp
+++ b/oneflow/core/functional/impl/nn_functor.cpp
@@ -1265,7 +1265,7 @@ class NormalizationAddReluFunctor {
                                       .Output("y")
                                       .Attr("training", false)
                                       .Build());
-    relu_op_ = CHECK_JUST(one::OpBuilder("relu").Input("in").Output("out").Build());
+    relu_op_ = CHECK_JUST(one::OpBuilder("relu").Input("x").Output("y").Build());
     add_op_ = CHECK_JUST(one::OpBuilder("add_n").Input("in", 2).Output("out").Build());
     fused_norm_training_stats_op_ = CHECK_JUST(one::OpBuilder("normalization_add_relu")
                                                    .Input("x")
@@ -2187,4 +2187,4 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
 }  // namespace functional
 }  // namespace one
-}  // namespace oneflow
\ No newline at end of file
+}  // namespace oneflow
diff --git a/oneflow/core/operator/user_op.cpp b/oneflow/core/operator/user_op.cpp
index ea529fc9c8c..8a85910ae2b 100644
--- a/oneflow/core/operator/user_op.cpp
+++ b/oneflow/core/operator/user_op.cpp
@@ -135,8 +135,9 @@ class UserOpInferContext final : public user_op::InferContext {
     auto InitTensorDesc = [&](const ArgVec& arg_vec, const PbRpf<std::string>& bns) {
       CHECK_EQ(arg_vec.size(), bns.size());
       for (int32_t i = 0; i < arg_vec.size(); ++i) {
+        const auto& bn_i = bns.Get(i);
         BlobDesc* blob = GetBlobDesc4BnInOp(bns.Get(i));
-        CHECK_NOTNULL(blob);
+        CHECK(blob != nullptr) << bn_i;
         arg2tensor_desc_.emplace(arg_vec.at(i), GenTensorDescFromBlobDesc(blob));
       }
     };
diff --git a/oneflow/ir/CMakeLists.txt b/oneflow/ir/CMakeLists.txt
index fdb9ab2bb82..b0a17da4797 100644
--- a/oneflow/ir/CMakeLists.txt
+++ b/oneflow/ir/CMakeLists.txt
@@ -30,6 +30,7 @@ else()
   message(FATAL_ERROR "LLVM_PROVIDER should be in-tree or install, but got: ${LLVM_PROVIDER}")
 endif()
 
+set_property(GLOBAL PROPERTY LLVM_INSTALL_DIR ${LLVM_INSTALL_DIR})
 set(MLIR_TABLEGEN_EXE mlir-tblgen)
 
 include_directories(${LLVM_INCLUDE_DIRS})
@@ -43,6 +44,10 @@ include_directories(${PROJECT_BINARY_DIR}/include)
 link_directories(${LLVM_BUILD_LIBRARY_DIR})
 add_definitions(${LLVM_DEFINITIONS})
 
+if(LLVM_PROVIDER STREQUAL "in-tree")
+  add_subdirectory(${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen ${PROJECT_BINARY_DIR}/oneflow-tblgen)
+endif()
+
 set_property(GLOBAL PROPERTY ALL_ONEFLOW_LIBS
   -Wl,--no-as-needed oneflow -Wl,--as-needed
   -Wl,--no-as-needed ${oneflow_exe_third_party_libs} -Wl,--as-needed
@@ -74,6 +79,7 @@ set(LLVM_PTHREAD_LIB ${CMAKE_THREAD_LIBS_INIT})
 set(LLVM_RUNTIME_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/bin)
 set(LLVM_LIBRARY_OUTPUT_INTDIR ${PROJECT_BINARY_DIR}/lib)
 
+if(WITH_MLIR)
 add_subdirectory(include)
 add_subdirectory(lib)
 add_subdirectory(test)
@@ -81,3 +87,4 @@ add_subdirectory(oneflow-opt)
 add_subdirectory(oneflow-translate)
 add_subdirectory(oneflow-runtime)
 add_subdirectory(oneflow-extension)
+endif(WITH_MLIR)
diff --git a/oneflow/ir/include/OneFlow/CMakeLists.txt b/oneflow/ir/include/OneFlow/CMakeLists.txt
index 81891cf8d27..8dd61fe2eb1 100644
--- a/oneflow/ir/include/OneFlow/CMakeLists.txt
+++ b/oneflow/ir/include/OneFlow/CMakeLists.txt
@@ -1,11 +1,5 @@
-set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_BINARY_DIR}/include/OneFlow")
-message(STATUS "Generating user op ODS ${ONEFLOW_USER_OP_GEN_TD_PATH}/OneFlowUserOpGen.td")
-add_custom_target(GenUserOpODS
-  DEPENDS oneflow-gen-ods
-  COMMAND "$<TARGET_FILE:oneflow-gen-ods>"
-  BYPRODUCTS OneFlowUserOpGen.td
-  WORKING_DIRECTORY "${ONEFLOW_USER_OP_GEN_TD_PATH}"
-)
+# set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_BINARY_DIR}/include/OneFlow")
+set(ONEFLOW_USER_OP_GEN_TD_PATH "${PROJECT_SOURCE_DIR}/include/OneFlow")
 
 set(LLVM_TARGET_DEFINITIONS
 OneFlowEnums.td)
 mlir_tablegen(OneFlowEnums.h.inc -gen-enum-decls)
@@ -19,7 +13,6 @@ foreach (OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS_USED_IN_PATTERNS)
 endforeach()
 mlir_tablegen(OneFlowPatterns.cpp.inc -gen-rewriters)
 add_public_tablegen_target(MLIROneFlowPatternsIncGen)
-add_dependencies(MLIROneFlowPatternsIncGen GenUserOpODS)
 
 # NOTE: seperate conversion and opt with --name
 set(LLVM_TARGET_DEFINITIONS OneFlowOps.td)
@@ -44,12 +37,10 @@ foreach (OP_GROUP_NAME IN LISTS ONEFLOW_OP_GROUPS)
   mlir_tablegen(${HEADER_INC_FILE} -gen-op-decls)
 endforeach()
 add_public_tablegen_target(MLIROneFlowOpGroupDefsIncGen)
-add_dependencies(MLIROneFlowOpGroupDefsIncGen GenUserOpODS)
 
 set(LLVM_TABLEGEN_FLAGS "${FULL_LLVM_TABLEGEN_FLAGS}")
 mlir_tablegen(OneFlow.gen_ops.h.inc -gen-op-decls)
 add_public_tablegen_target(MLIROneFlowOpGroupDeclsIncGen)
-add_dependencies(MLIROneFlowOpGroupDeclsIncGen GenUserOpODS)
 
 set(LLVM_TABLEGEN_FLAGS "")
 add_mlir_dialect(
diff --git a/oneflow/ir/include/OneFlow/OneFlowBase.td b/oneflow/ir/include/OneFlow/OneFlowBase.td
index 5591d83ed31..36d2b0cba89 100644
--- a/oneflow/ir/include/OneFlow/OneFlowBase.td
+++ b/oneflow/ir/include/OneFlow/OneFlowBase.td
@@ -33,8 +33,8 @@ class OneFlow_BaseOp<string mnemonic, list<OpTrait> traits = []> :
   dag attrs = (ins);
   dag trait_attrs = (ins);
   dag user_op_attrs = (ins);
-  dag input = (ins Variadic<AnyType>:$data_input);
-  dag output = (outs Variadic<AnyType>:$data_output);
+  dag input = (ins);
+  dag output = (outs);
   dag ctrl_input = (ins);
   dag ctrl_output = (outs);
   let arguments = !con(
@@ -49,6 +49,19 @@ class OneFlow_BaseOp<string mnemonic, list<OpTrait> traits = []> :
     output,
     ctrl_output
   );
+  int same_output_regst_num = -1;
+
+  bit has_check_fn = 0;
+  bit has_logical_tensor_desc_infer_fn = 0;
+  bit has_physical_tensor_desc_infer_fn = 0;
+  bit has_get_sbp_fn = 0;
+  bit has_sbp_signature_infer_fn = 0;
+  bit has_data_type_infer_fn = 0;
+  bit has_device_infer_fn = 0;
+  bit has_input_arg_modify_fn = 0;
+  bit has_output_arg_modify_fn = 0;
+  bit has_output_blob_time_shape_infer_fn = 0;
+  bit has_nd_sbp_infer_fn = 0;
 }
 
 class OneFlow_Op<string mnemonic, list<OpTrait> traits = []> :
@@ -102,17 +115,22 @@ class OneFlow_ConvolutionBaseOp<string mnemonic, list<OpTrait> traits = []> :
   );
   let output = (outs AnyType:$out);
   let attrs = (ins
-    SI32Attr:$filters,
+    DefaultValuedAttr<SI32Attr, "0">:$filters,
     SI32ArrayAttr:$padding_before,
     StrAttr:$data_format,
     SI32ArrayAttr:$kernel_size,
     SI32ArrayAttr:$strides,
     SI32ArrayAttr:$dilation_rate,
-    DefaultValuedAttr<SI32Attr, "1">:$group
+    DefaultValuedAttr<SI32Attr, "1">:$groups
   );
   let trait_attrs = (ins
     I32ElementsAttr:$operand_segment_sizes
   );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_TFPoolBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -129,6 +147,10 @@ class OneFlow_TFPoolBaseOp<string mnemonic, list<OpTrait> traits = []> :
     SI32ArrayAttr:$strides,
     BoolAttr:$ceil_mode
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_TFPoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -149,6 +171,10 @@ class OneFlow_TFPoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
     SI32ArrayAttr:$strides,
     BoolAttr:$ceil_mode
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 
@@ -171,6 +197,10 @@ class OneFlow_MaxPoolBaseOp<string mnemonic, list<OpTrait> traits = []> :
     DefaultValuedAttr<BoolAttr, "false">:$return_indices,
     DefaultValuedAttr<BoolAttr, "false">:$ceil_mode
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
@@ -191,6 +221,10 @@ class OneFlow_AvgPoolBaseOp<string mnemonic, list<OpTrait> traits = []> :
     DefaultValuedAttr<BoolAttr, "false">:$count_include_pad,
     DefaultValuedAttr<SI32Attr, "0">:$divisor_override
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_MaxPoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -214,6 +248,10 @@ class OneFlow_MaxPoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
     DefaultValuedAttr<BoolAttr, "false">:$return_indices,
     DefaultValuedAttr<BoolAttr, "false">:$ceil_mode
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_AvgPoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -236,6 +274,10 @@ class OneFlow_AvgPoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
     DefaultValuedAttr<BoolAttr, "false">:$count_include_pad,
     DefaultValuedAttr<SI32Attr, "0">:$divisor_override
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_AdaptivePoolBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -248,6 +290,10 @@ class OneFlow_AdaptivePoolBaseOp<string mnemonic, list<OpTrait> traits = []> :
   let attrs = (ins
     SI64ArrayAttr:$output_size
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_AdaptivePoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -261,6 +307,10 @@ class OneFlow_AdaptivePoolGradBaseOp<string mnemonic, list<OpTrait> traits = []> :
   let attrs = (ins
     SI64ArrayAttr:$output_size
   );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 class OneFlow_UnaryBaseOp<string mnemonic, list<OpTrait> traits = []> :
@@ -268,21 +318,23 @@ class OneFlow_UnaryBaseOp<string mnemonic, list<OpTrait> traits = []> :
   let summary = "";
   let input = (ins AnyType:$x);
   let output = (outs AnyType:$y);
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
 }
 
 def OneFlow_Idempotent : NativeOpTrait<"IsIdempotentOfIdenticalPlacement">;
 
 class OneFlow_IdempotentBaseOp<string mnemonic, list<OpTrait> traits = []> :
-    OneFlow_UnaryBaseOp<mnemonic, !listconcat(traits, [OneFlow_Idempotent])> {
-}
+    OneFlow_UnaryBaseOp<mnemonic, !listconcat(traits, [OneFlow_Idempotent])> {}
 
 def OneFlow_Involution : NativeOpTrait<"IsInvolutionOfIdenticalPlacement">;
 
 class OneFlow_InvolutionBaseOp<string mnemonic, list<OpTrait> traits = []> :
-    OneFlow_UnaryBaseOp<mnemonic, !listconcat(traits, [OneFlow_Involution])> {
-}
+    OneFlow_UnaryBaseOp<mnemonic, !listconcat(traits, [OneFlow_Involution])> {}
 
 #define GET_ONEFLOW_BASE_OP_DEFINITIONS
-include "OneFlow/OneFlowUserOpGen.td"
+include "OneFlow/OneFlowUserOps.td"
 
 #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWBASE_H_
diff --git a/oneflow/ir/include/OneFlow/OneFlowInterfaces.td b/oneflow/ir/include/OneFlow/OneFlowInterfaces.td
index 85e8a6f6e6e..1aa344884bd 100644
--- a/oneflow/ir/include/OneFlow/OneFlowInterfaces.td
+++ b/oneflow/ir/include/OneFlow/OneFlowInterfaces.td
@@ -73,4 +73,15 @@ def ControlEdgeCompatibleInterface : OpInterface<"ControlEdgeCompatible"> {
   ];
 }
 
+def NoGrad : OpInterface<"NoGrad"> {
+  let description = [{
+  }];
+}
+
+def CpuOnly : OpInterface<"CpuOnly"> {
+  let description = [{
+  }];
+}
+
+
 #endif // ONEFLOW_IR_INCLUDE_ONEFLOW_ONEFLOWINTERFACES_H_
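The `has_*` bits added to `OneFlow_BaseOp` above are consumed by `oneflow-tblgen` (see `cmake/op_schema.cmake`), not by MLIR itself; each enabled bit presumably turns into one declared hook on the generated schema class. The sketch below is a hypothetical rendering of that output; the names and signatures are assumptions modeled on OneFlow's existing `user_op` infer hooks, not actual generator output:

```cpp
#include "oneflow/core/common/maybe.h"

namespace oneflow {
namespace user_op {
class InferContext;  // forward declarations; real types live in user_op headers
class SbpContext;
}  // namespace user_op

// Hypothetical schema entry for an op that sets has_logical_tensor_desc_infer_fn,
// has_physical_tensor_desc_infer_fn, has_get_sbp_fn and has_data_type_infer_fn.
struct MaxPool2DOpSchema {
  static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
  static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);
  static Maybe<void> GetSbp(user_op::SbpContext* ctx);
  static Maybe<void> InferDataType(user_op::InferContext* ctx);
};

}  // namespace oneflow
```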
(ins - StrArrayAttr:$output_lbns - ); - let hasCanonicalizer = 1; -} - -def OneFlow_SystemOp : OneFlow_Op<"system", [OneFlow_IsImportCompatible]> { - let summary = ""; - let attrs = (ins - StrArrayAttr:$input_bns, - StrArrayAttr:$output_lbns - ); - let hasCanonicalizer = 1; -} - def OneFlow_NormalizationAddReluOp : OneFlow_NormalizationAddReluBaseOp { let builders = [ OpBuilder<(ins @@ -52,6 +35,29 @@ def OneFlow_NormalizationAddReluOp : OneFlow_NormalizationAddReluBaseOp { ]; } +#ifndef REMOVE_ONEFLOW_MLIR_ONLY_OP_DEFINITIONS + +def OneFlow_UserOp : OneFlow_UserBaseWithCtrlOp<"user", [OneFlow_IsImportCompatible]> { + let summary = ""; + let input = (ins Variadic:$data_input); + let output = (outs Variadic:$data_output); + let attrs = (ins + StrArrayAttr:$output_lbns + ); + let hasCanonicalizer = 1; +} + +def OneFlow_SystemOp : OneFlow_Op<"system", [OneFlow_IsImportCompatible]> { + let summary = ""; + let input = (ins Variadic:$data_input); + let output = (outs Variadic:$data_output); + let attrs = (ins + StrArrayAttr:$input_bns, + StrArrayAttr:$output_lbns + ); + let hasCanonicalizer = 1; +} + def OneFlow_Add2Op : OneFlow_BaseOp<"add_n2", [NoSideEffect, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { let summary = ""; let input = (ins @@ -64,6 +70,8 @@ def OneFlow_Add2Op : OneFlow_BaseOp<"add_n2", [NoSideEffect, DeclareOpInterfaceM // JIT ops def OneFlow_MlirJitOp : OneFlow_BaseOp<"mlir_jit", [ CallOpInterface, DeclareOpInterfaceMethods ] > { + let input = (ins Variadic:$data_input); + let output = (outs Variadic:$data_output); let attrs = (ins FlatSymbolRefAttr:$callee, StrAttr:$mlir_assembly @@ -224,6 +232,8 @@ def OneFlow_ReturnOp : Op]> { + let input = (ins + AnyType:$ref, + AnyType:$value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AssignIfOp : OneFlow_BaseOp<"assign_if", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$ref, + AnyType:$value, + AnyType:$condition + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AssignIfNotOp : OneFlow_BaseOp<"assign_if_not", [NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$ref, + AnyType:$value, + AnyType:$condition + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_LogicalSliceAssignOp : OneFlow_BaseOp<"logical_slice_assign", [DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$ref, + AnyType:$value + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_ASSIGN_OP_DEFINITIONS + +// Group: BASE +// normalization_add_relu +// Total: 1 + +#ifdef GET_ONEFLOW_BASE_OP_DEFINITIONS + +class OneFlow_NormalizationAddReluBaseOp : OneFlow_BaseOp<"normalization_add_relu", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + Optional:$addend, + Optional:$moving_mean, + Optional:$moving_variance, + AnyType:$gamma, 
+ AnyType:$beta + ); + let output = (outs + AnyType:$y, + AnyType:$reserve_space, + Optional:$mean, + Optional:$inv_variance + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$training, + DefaultValuedAttr:$momentum + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes, + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_BASE_OP_DEFINITIONS + +// Group: BINARY +// bias_add, cast_like, celu_grad, diag_grad, diagonal_grad, dot, dropout_grad, elementwise_maximum, elementwise_minimum, elu_grad, floordiv, gelu_grad, grid_sample, hardsigmoid_grad, hardswish_grad, l1_l2_regularize_gradient, leaky_relu_grad, masked_fill, mish_grad, multiply, narrow_grad, pow, prelu, relu_grad, selu_grad, sigmoid_grad, silu_grad, tf_prelu, unfold_tensor_grad, xdivy, xlogy +// Total: 31 + +#ifdef GET_ONEFLOW_BINARY_OP_DEFINITIONS + +def OneFlow_BiasAddOp : OneFlow_BaseOp<"bias_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$a, + AnyType:$b + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastLikeOp : OneFlow_BaseOp<"cast_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$dtype_like + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_CeluGradOp : OneFlow_BaseOp<"celu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagGradOp : OneFlow_BaseOp<"diag_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$in + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$diagonal + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagonalGradOp : OneFlow_BaseOp<"diagonal_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$in + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$offset + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DotOp : OneFlow_BaseOp<"dot", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DropoutGradOp : OneFlow_BaseOp<"dropout_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + 
AnyType:$dy, + AnyType:$mask + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMaximumOp : OneFlow_BaseOp<"elementwise_maximum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMinimumOp : OneFlow_BaseOp<"elementwise_minimum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_EluGradOp : OneFlow_BaseOp<"elu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FloordivOp : OneFlow_BaseOp<"floordiv", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GeluGradOp : OneFlow_BaseOp<"gelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GridSampleOp : OneFlow_BaseOp<"grid_sample", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input, + AnyType:$grid + ); + let output = (outs + AnyType:$output + ); + let attrs = (ins + StrAttr:$interpolation_mode, + StrAttr:$padding_mode, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardsigmoidGradOp : OneFlow_BaseOp<"hardsigmoid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardswishGradOp : OneFlow_BaseOp<"hardswish_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_L1L2RegularizeGradientOp : OneFlow_BaseOp<"l1_l2_regularize_gradient", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$model, + AnyType:$model_diff + ); + let output = (outs + 
AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$l1, + DefaultValuedAttr:$l2 + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LeakyReluGradOp : OneFlow_BaseOp<"leaky_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MaskedFillOp : OneFlow_BaseOp<"masked_fill", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$mask + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_MishGradOp : OneFlow_BaseOp<"mish_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MultiplyOp : OneFlow_BaseOp<"multiply", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NarrowGradOp : OneFlow_BaseOp<"narrow_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$like + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$dim, + DefaultValuedAttr:$start, + DefaultValuedAttr:$length + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_PowOp : OneFlow_BaseOp<"pow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_PreluOp : OneFlow_BaseOp<"prelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$alpha + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ReluGradOp : OneFlow_BaseOp<"relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$y, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SeluGradOp : OneFlow_BaseOp<"selu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy 
+ ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidGradOp : OneFlow_BaseOp<"sigmoid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$y, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SiluGradOp : OneFlow_BaseOp<"silu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TfPreluOp : OneFlow_BaseOp<"tf_prelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$alpha + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnfoldTensorGradOp : OneFlow_BaseOp<"unfold_tensor_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$dimension, + DefaultValuedAttr:$size, + DefaultValuedAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XdivyOp : OneFlow_BaseOp<"xdivy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_XlogyOp : OneFlow_BaseOp<"xlogy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_BINARY_OP_DEFINITIONS + +// Group: BROADCAST +// broadcast_add, broadcast_div, broadcast_div_grad, broadcast_equal, broadcast_floor_mod, broadcast_fmod, broadcast_greater, broadcast_greater_equal, broadcast_less, broadcast_less_equal, broadcast_like, broadcast_logical_and, broadcast_logical_or, broadcast_logical_xor, broadcast_maximum, broadcast_minimum, broadcast_mul, broadcast_not_equal, broadcast_pow, broadcast_pow_x_grad, broadcast_pow_y_grad, broadcast_sub +// Total: 22 + +#ifdef GET_ONEFLOW_BROADCAST_OP_DEFINITIONS + +def OneFlow_BroadcastAddOp : OneFlow_BaseOp<"broadcast_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastDivOp : OneFlow_BaseOp<"broadcast_div", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let 
has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastDivGradOp : OneFlow_BaseOp<"broadcast_div_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$y, + AnyType:$z, + AnyType:$dz + ); + let output = (outs + AnyType:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastEqualOp : OneFlow_BaseOp<"broadcast_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastFloorModOp : OneFlow_BaseOp<"broadcast_floor_mod", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastFmodOp : OneFlow_BaseOp<"broadcast_fmod", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastGreaterOp : OneFlow_BaseOp<"broadcast_greater", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastGreaterEqualOp : OneFlow_BaseOp<"broadcast_greater_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLessOp : OneFlow_BaseOp<"broadcast_less", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLessEqualOp : OneFlow_BaseOp<"broadcast_less_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLikeOp : OneFlow_BaseOp<"broadcast_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$like + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + SI32ArrayAttr:$broadcast_axes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_BroadcastLogicalAndOp : 
OneFlow_BaseOp<"broadcast_logical_and", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLogicalOrOp : OneFlow_BaseOp<"broadcast_logical_or", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastLogicalXorOp : OneFlow_BaseOp<"broadcast_logical_xor", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMaximumOp : OneFlow_BaseOp<"broadcast_maximum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMinimumOp : OneFlow_BaseOp<"broadcast_minimum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastMulOp : OneFlow_BaseOp<"broadcast_mul", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastNotEqualOp : OneFlow_BaseOp<"broadcast_not_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastPowOp : OneFlow_BaseOp<"broadcast_pow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastPowXGradOp : OneFlow_BaseOp<"broadcast_pow_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y, + AnyType:$z, + AnyType:$dz + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastPowYGradOp : OneFlow_BaseOp<"broadcast_pow_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y, + AnyType:$z, + AnyType:$dz + ); + let output = (outs + AnyType:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let 
has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BroadcastSubOp : OneFlow_BaseOp<"broadcast_sub", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_BROADCAST_OP_DEFINITIONS + +// Group: CONV +// conv1d, conv2d, conv3d, conv_bias_grad, conv_data_grad, conv_filter_grad, deconv1d, deconv2d, deconv3d +// Total: 9 + +#ifdef GET_ONEFLOW_CONV_OP_DEFINITIONS + +def OneFlow_Conv1DOp : OneFlow_ConvolutionBaseOp<"conv1d", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} + +def OneFlow_Conv2DOp : OneFlow_ConvolutionBaseOp<"conv2d", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} + +def OneFlow_Conv3DOp : OneFlow_ConvolutionBaseOp<"conv3d", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> {} + +def OneFlow_ConvBiasGradOp : OneFlow_BaseOp<"conv_bias_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy + ); + let output = (outs + AnyType:$bias_diff + ); + let attrs = (ins + StrAttr:$data_format, + DefaultValuedAttr:$num_spatial_dims + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConvDataGradOp : OneFlow_BaseOp<"conv_data_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$filter, + AnyType:$x_like, + Optional:$_add_to_output + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$num_spatial_dims, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConvFilterGradOp : OneFlow_BaseOp<"conv_filter_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$filter_diff + ); + let attrs = (ins + DefaultValuedAttr:$num_spatial_dims, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Deconv1DOp : OneFlow_BaseOp<"deconv1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$weight + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$filters, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$output_padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Deconv2DOp : OneFlow_BaseOp<"deconv2d", [NoSideEffect, 
DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$weight + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$filters, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$output_padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Deconv3DOp : OneFlow_BaseOp<"deconv3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$weight + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$filters, + SI32ArrayAttr:$padding_before, + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$output_padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate, + DefaultValuedAttr:$groups + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_CONV_OP_DEFINITIONS + +// Group: CROSS_ENTROPY +// binary_cross_entropy, binary_cross_entropy_grad, binary_cross_entropy_with_logits, binary_cross_entropy_with_logits_grad, sigmoid_cross_entropy, sigmoid_cross_entropy_grad, sparse_cross_entropy, sparse_cross_entropy_grad, sparse_cross_entropy_ms, sparse_cross_entropy_ms_grad +// Total: 10 + +#ifdef GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS + +def OneFlow_BinaryCrossEntropyOp : OneFlow_BaseOp<"binary_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input, + AnyType:$target, + Optional:$weight + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_BinaryCrossEntropyGradOp : OneFlow_BaseOp<"binary_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input, + AnyType:$target, + Optional:$weight, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BinaryCrossEntropyWithLogitsOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input, + AnyType:$target, + Optional:$weight, + Optional:$pos_weight + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_pos_weight + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_BinaryCrossEntropyWithLogitsGradOp : OneFlow_BaseOp<"binary_cross_entropy_with_logits_grad", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input, + AnyType:$target, + Optional:$weight, + Optional:$pos_weight, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$has_pos_weight + ); + let trait_attrs = (ins + 
I32ElementsAttr:$operand_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidCrossEntropyOp : OneFlow_BaseOp<"sigmoid_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$label + ); + let output = (outs + AnyType:$loss + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SigmoidCrossEntropyGradOp : OneFlow_BaseOp<"sigmoid_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$loss_diff, + AnyType:$label + ); + let output = (outs + AnyType:$prediction_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseCrossEntropyOp : OneFlow_BaseOp<"sparse_cross_entropy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$label + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseCrossEntropyGradOp : OneFlow_BaseOp<"sparse_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$label, + AnyType:$dy + ); + let output = (outs + AnyType:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SparseCrossEntropyMsOp : OneFlow_BaseOp<"sparse_cross_entropy_ms", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$label + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseCrossEntropyMsGradOp : OneFlow_BaseOp<"sparse_cross_entropy_ms_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$label, + AnyType:$dy + ); + let output = (outs + AnyType:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_CROSS_ENTROPY_OP_DEFINITIONS + +// Group: CUDA +// nvtx_end, nvtx_start +// Total: 2 + +#ifdef GET_ONEFLOW_CUDA_OP_DEFINITIONS + +def OneFlow_NvtxEndOp : OneFlow_BaseOp<"nvtx_end", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$mark_prefix + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NvtxStartOp : OneFlow_BaseOp<"nvtx_start", [NoSideEffect, 
DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$mark_prefix + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_CUDA_OP_DEFINITIONS + +// Group: DATASET +// COCOReader, OFRecordReader, OneRecReader, ctc_greedy_decoder, megatron_gpt_mmap_data_loader, ofrecord_bytes_decoder, ofrecord_image_classification_reader, ofrecord_image_decoder, ofrecord_image_decoder_random_crop, ofrecord_raw_decoder, onerec_decoder +// Total: 11 + +#ifdef GET_ONEFLOW_DATASET_OP_DEFINITIONS + +def OneFlow_COCOReaderOp : OneFlow_BaseOp<"COCOReader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$image, + AnyType:$image_id, + AnyType:$image_size, + AnyType:$gt_bbox, + AnyType:$gt_label, + AnyType:$gt_segm, + AnyType:$gt_segm_index + ); + let attrs = (ins + DefaultValuedAttr:$session_id, + StrAttr:$annotation_file, + StrAttr:$image_dir, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$shuffle_after_epoch, + DefaultValuedAttr:$random_seed, + DefaultValuedAttr:$group_by_ratio, + DefaultValuedAttr:$remove_images_without_annotations, + DefaultValuedAttr:$stride_partition, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_OFRecordReaderOp : OneFlow_BaseOp<"OFRecordReader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$data_dir, + DefaultValuedAttr:$data_part_num, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$part_name_prefix, + DefaultValuedAttr:$part_name_suffix_length, + DefaultValuedAttr:$random_shuffle, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$shuffle_buffer_size, + DefaultValuedAttr:$shuffle_after_epoch, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_OneRecReaderOp : OneFlow_BaseOp<"OneRecReader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrArrayAttr:$files, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$random_shuffle, + DefaultValuedAttr:$shuffle_mode, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$shuffle_buffer_size, + DefaultValuedAttr:$shuffle_after_epoch, + DefaultValuedAttr:$verify_example + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CtcGreedyDecoderOp : OneFlow_BaseOp<"ctc_greedy_decoder", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$log_probs, + AnyType:$input_lengths + ); + let output = (outs + AnyType:$decoded, + AnyType:$neg_sum_logits + ); + let attrs = (ins + DefaultValuedAttr:$merge_repeated + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MegatronGptMmapDataLoaderOp : OneFlow_BaseOp<"megatron_gpt_mmap_data_loader", [NoSideEffect, NoGrad, CpuOnly, 
DeclareOpInterfaceMethods]> { + let input = (ins + Optional:$iteration + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$data_file_prefix, + DefaultValuedAttr:$seq_length, + DefaultValuedAttr:$label_length, + DefaultValuedAttr:$num_samples, + DefaultValuedAttr:$batch_size, + OneFlow_DataType:$dtype, + SI64ArrayAttr:$split_sizes, + DefaultValuedAttr:$split_index, + DefaultValuedAttr:$shuffle, + DefaultValuedAttr:$random_seed, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_OfrecordBytesDecoderOp : OneFlow_BaseOp<"ofrecord_bytes_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$name + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OfrecordImageClassificationReaderOp : OneFlow_BaseOp<"ofrecord_image_classification_reader", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$image, + AnyType:$label + ); + let attrs = (ins + StrAttr:$data_dir, + DefaultValuedAttr:$data_part_num, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$part_name_prefix, + DefaultValuedAttr:$part_name_suffix_length, + DefaultValuedAttr:$random_shuffle, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$shuffle_buffer_size, + DefaultValuedAttr:$shuffle_after_epoch, + DefaultValuedAttr:$color_space, + DefaultValuedAttr:$image_feature_name, + DefaultValuedAttr:$label_feature_name, + DefaultValuedAttr:$decode_buffer_size_per_thread, + DefaultValuedAttr:$num_decode_threads_per_machine + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_OfrecordImageDecoderOp : OneFlow_BaseOp<"ofrecord_image_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$name, + DefaultValuedAttr:$color_space + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OfrecordImageDecoderRandomCropOp : OneFlow_BaseOp<"ofrecord_image_decoder_random_crop", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$name, + DefaultValuedAttr:$color_space, + DefaultValuedAttr:$num_attempts, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + F32ArrayAttr:$random_area, + F32ArrayAttr:$random_aspect_ratio + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OfrecordRawDecoderOp : OneFlow_BaseOp<"ofrecord_raw_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$name, + ShapeAttr:$shape, + OneFlow_DataType:$data_type, + 
DefaultValuedAttr:$dim1_varying_length, + DefaultValuedAttr:$truncate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_OnerecDecoderOp : OneFlow_BaseOp<"onerec_decoder", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$key, + OneFlow_DataType:$data_type, + ShapeAttr:$static_shape, + DefaultValuedAttr:$is_dynamic, + DefaultValuedAttr:$has_reshape, + ShapeAttr:$reshape, + DefaultValuedAttr:$has_batch_padding, + ShapeAttr:$batch_padding + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; + let has_output_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_DATASET_OP_DEFINITIONS + +// Group: DETECTION +// in_top_k, nms, object_bbox_flip, object_bbox_scale, object_segmentation_polygon_flip, object_segmentation_polygon_scale, object_segmentation_polygon_to_mask, roi_align, roi_align_grad, top_k +// Total: 10 + +#ifdef GET_ONEFLOW_DETECTION_OP_DEFINITIONS + +def OneFlow_InTopKOp : OneFlow_BaseOp<"in_top_k", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$targets, + AnyType:$predictions + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$k + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NmsOp : OneFlow_BaseOp<"nms", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$iou_threshold, + DefaultValuedAttr:$keep_n + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectBboxFlipOp : OneFlow_BaseOp<"object_bbox_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$bbox, + AnyType:$image_size, + AnyType:$flip_code + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectBboxScaleOp : OneFlow_BaseOp<"object_bbox_scale", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$bbox, + AnyType:$scale + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectSegmentationPolygonFlipOp : OneFlow_BaseOp<"object_segmentation_polygon_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$poly, + AnyType:$image_size, + AnyType:$flip_code + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectSegmentationPolygonScaleOp : OneFlow_BaseOp<"object_segmentation_polygon_scale", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$poly, + AnyType:$scale + ); + 
let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ObjectSegmentationPolygonToMaskOp : OneFlow_BaseOp<"object_segmentation_polygon_to_mask", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$poly, + AnyType:$poly_index, + AnyType:$image_size + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RoiAlignOp : OneFlow_BaseOp<"roi_align", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$rois + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$pooled_h, + DefaultValuedAttr:$pooled_w, + DefaultValuedAttr:$spatial_scale, + DefaultValuedAttr:$sampling_ratio, + DefaultValuedAttr:$aligned + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_RoiAlignGradOp : OneFlow_BaseOp<"roi_align_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x_like, + AnyType:$rois + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$pooled_h, + DefaultValuedAttr:$pooled_w, + DefaultValuedAttr:$spatial_scale, + DefaultValuedAttr:$sampling_ratio, + DefaultValuedAttr:$aligned + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TopKOp : OneFlow_BaseOp<"top_k", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$k, + DefaultValuedAttr:$sorted + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_DETECTION_OP_DEFINITIONS + +// Group: EAGER +// eager_b_to_s, eager_naive_s_to_s, eager_nccl_all_gather, eager_nccl_all_reduce, eager_nccl_broadcast, eager_nccl_reduce, eager_nccl_reduce_scatter, eager_nccl_s2s, eager_p_to_b, eager_p_to_s, eager_s_to_b, eager_symmetric_s_to_p +// Total: 12 + +#ifdef GET_ONEFLOW_EAGER_OP_DEFINITIONS + +def OneFlow_EagerBToSOp : OneFlow_BaseOp<"eager_b_to_s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$out_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNaiveSToSOp : OneFlow_BaseOp<"eager_naive_s_to_s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + DefaultValuedAttr:$out_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let 
has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNcclAllGatherOp : OneFlow_BaseOp<"eager_nccl_all_gather", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$parallel_conf + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNcclAllReduceOp : OneFlow_BaseOp<"eager_nccl_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$async_launch + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_EagerNcclBroadcastOp : OneFlow_BaseOp<"eager_nccl_broadcast", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$root + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_EagerNcclReduceOp : OneFlow_BaseOp<"eager_nccl_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$root + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_EagerNcclReduceScatterOp : OneFlow_BaseOp<"eager_nccl_reduce_scatter", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$parallel_conf, + DefaultValuedAttr:$op_type + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerNcclS2sOp : OneFlow_BaseOp<"eager_nccl_s2s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + DefaultValuedAttr:$out_split_axis, + StrAttr:$parallel_conf + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerPToBOp : OneFlow_BaseOp<"eager_p_to_b", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerPToSOp : 
OneFlow_BaseOp<"eager_p_to_s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$out_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerSToBOp : OneFlow_BaseOp<"eager_s_to_b", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + StrAttr:$in_parallel_conf, + StrAttr:$out_parallel_conf, + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_EagerSymmetricSToPOp : OneFlow_BaseOp<"eager_symmetric_s_to_p", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$in_split_axis, + StrAttr:$parallel_conf + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +#endif // GET_ONEFLOW_EAGER_OP_DEFINITIONS + +// Group: FUSED +// cudnn_fused_normalization_add_relu, cudnn_fused_normalization_add_relu_grad, fused_bias_add_gelu, fused_bias_add_gelu_grad, fused_bias_add_mask_scale, fused_cast_scale, fused_scale_mask_softmax, fused_scale_mask_softmax_dropout, fused_scale_mask_softmax_dropout_grad, fused_scale_mask_softmax_grad, fused_scale_tril, fused_self_attention_query_mul_key_and_value, fused_self_attention_query_mul_key_and_value_grad, fused_tril_scale_softmax_mask_scale, fused_tril_scale_softmax_mask_scale_grad, normalization_add_relu_grad +// Total: 16 + +#ifdef GET_ONEFLOW_FUSED_OP_DEFINITIONS + +def OneFlow_CudnnFusedNormalizationAddReluOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + Optional:$addend, + Optional:$moving_mean, + Optional:$moving_variance, + AnyType:$gamma, + AnyType:$beta + ); + let output = (outs + AnyType:$y, + AnyType:$reserve_space, + Optional:$mean, + Optional:$inv_variance + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon, + DefaultValuedAttr:$momentum + ); + let trait_attrs = (ins + I32ElementsAttr:$operand_segment_sizes, + I32ElementsAttr:$result_segment_sizes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_CudnnFusedNormalizationAddReluGradOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy, + AnyType:$mean, + AnyType:$inv_variance, + AnyType:$gamma, + AnyType:$beta, + AnyType:$reserve_space, + AnyType:$y + ); + let output = (outs + AnyType:$gamma_diff, + AnyType:$beta_diff, + AnyType:$dx, + Optional:$addend_diff + ); + let attrs = (ins + 
+
+// Group: FUSED
+// cudnn_fused_normalization_add_relu, cudnn_fused_normalization_add_relu_grad, fused_bias_add_gelu, fused_bias_add_gelu_grad, fused_bias_add_mask_scale, fused_cast_scale, fused_scale_mask_softmax, fused_scale_mask_softmax_dropout, fused_scale_mask_softmax_dropout_grad, fused_scale_mask_softmax_grad, fused_scale_tril, fused_self_attention_query_mul_key_and_value, fused_self_attention_query_mul_key_and_value_grad, fused_tril_scale_softmax_mask_scale, fused_tril_scale_softmax_mask_scale_grad, normalization_add_relu_grad
+// Total: 16
+
+#ifdef GET_ONEFLOW_FUSED_OP_DEFINITIONS
+
+def OneFlow_CudnnFusedNormalizationAddReluOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    Optional:$addend,
+    Optional:$moving_mean,
+    Optional:$moving_variance,
+    AnyType:$gamma,
+    AnyType:$beta
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$reserve_space,
+    Optional:$mean,
+    Optional:$inv_variance
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$momentum
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes,
+    I32ElementsAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_CudnnFusedNormalizationAddReluGradOp : OneFlow_BaseOp<"cudnn_fused_normalization_add_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy,
+    AnyType:$mean,
+    AnyType:$inv_variance,
+    AnyType:$gamma,
+    AnyType:$beta,
+    AnyType:$reserve_space,
+    AnyType:$y
+  );
+  let output = (outs
+    AnyType:$gamma_diff,
+    AnyType:$beta_diff,
+    AnyType:$dx,
+    Optional:$addend_diff
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$epsilon
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedBiasAddGeluOp : OneFlow_BaseOp<"fused_bias_add_gelu", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedBiasAddGeluGradOp : OneFlow_BaseOp<"fused_bias_add_gelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedBiasAddMaskScaleOp : OneFlow_BaseOp<"fused_bias_add_mask_scale", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b,
+    AnyType:$mask,
+    Optional:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$scale
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_FusedCastScaleOp : OneFlow_BaseOp<"fused_cast_scale", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$scale_by_tensor
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$scale
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedScaleMaskSoftmaxOp : OneFlow_BaseOp<"fused_scale_mask_softmax", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$mask
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$scale_value,
+    DefaultValuedAttr:$mask_fill_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_FusedScaleMaskSoftmaxDropoutOp : OneFlow_BaseOp<"fused_scale_mask_softmax_dropout", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$mask,
+    AnyType:$dropout_mask
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$softmax_y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$scale_value,
+    DefaultValuedAttr:$mask_fill_value,
+    DefaultValuedAttr:$dropout_scale_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_FusedScaleMaskSoftmaxDropoutGradOp : OneFlow_BaseOp<"fused_scale_mask_softmax_dropout_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$softmax_y,
+    AnyType:$dy,
+    AnyType:$mask,
+    AnyType:$dropout_mask
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$scale_value,
+    DefaultValuedAttr:$dropout_scale_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedScaleMaskSoftmaxGradOp : OneFlow_BaseOp<"fused_scale_mask_softmax_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$y,
+    AnyType:$dy,
+    AnyType:$mask
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$scale_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedScaleTrilOp : OneFlow_BaseOp<"fused_scale_tril", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$diagonal,
+    DefaultValuedAttr:$floating_fill_value,
+    DefaultValuedAttr:$integer_fill_value,
+    DefaultValuedAttr:$is_floating_fill_value,
+    DefaultValuedAttr:$floating_scale_value,
+    DefaultValuedAttr:$integer_scale_value,
+    DefaultValuedAttr:$is_floating_scale_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedSelfAttentionQueryMulKeyAndValueOp : OneFlow_BaseOp<"fused_self_attention_query_mul_key_and_value", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$hidden_states
+  );
+  let output = (outs
+    AnyType:$query_mul_key,
+    AnyType:$value
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$head_size,
+    DefaultValuedAttr:$alpha
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedSelfAttentionQueryMulKeyAndValueGradOp : OneFlow_BaseOp<"fused_self_attention_query_mul_key_and_value_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$query_mul_key_grad,
+    AnyType:$value_grad,
+    AnyType:$hidden_states
+  );
+  let output = (outs
+    AnyType:$hidden_states_grad
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$alpha
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FusedTrilScaleSoftmaxMaskScaleOp : OneFlow_BaseOp<"fused_tril_scale_softmax_mask_scale", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$mask
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$softmax_y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$diagonal,
+    DefaultValuedAttr:$tril_fill_value,
+    DefaultValuedAttr:$tril_scale_value,
+    DefaultValuedAttr:$mask_scale_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_FusedTrilScaleSoftmaxMaskScaleGradOp : OneFlow_BaseOp<"fused_tril_scale_softmax_mask_scale_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$softmax_y,
+    AnyType:$dy,
+    AnyType:$mask
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$diagonal,
+    DefaultValuedAttr:$tril_scale_value,
+    DefaultValuedAttr:$mask_scale_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
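+
+// NOTE (illustrative): the forward/backward pairing used throughout this
+// group passes forward activations (softmax_y, reserve_space) back into the
+// *_grad ops as explicit inputs, presumably so the fused backward kernel
+// does not have to recompute them.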
OneFlow_BaseOp<"normalization_add_relu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy, + AnyType:$mean, + AnyType:$inv_variance, + AnyType:$gamma, + AnyType:$beta, + AnyType:$reserve_space, + AnyType:$y + ); + let output = (outs + AnyType:$gamma_diff, + AnyType:$beta_diff, + AnyType:$dx, + Optional:$addend_diff + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$epsilon + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_FUSED_OP_DEFINITIONS + +// Group: IDEMPOTENT +// abs, ceil, floor, ones_like, relu, rint, round, sign +// Total: 8 + +#ifdef GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS + +def OneFlow_AbsOp : OneFlow_IdempotentBaseOp<"abs", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_CeilOp : OneFlow_IdempotentBaseOp<"ceil", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_FloorOp : OneFlow_IdempotentBaseOp<"floor", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_OnesLikeOp : OneFlow_IdempotentBaseOp<"ones_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let same_output_regst_num = 1; +} + +def OneFlow_ReluOp : OneFlow_IdempotentBaseOp<"relu", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_RintOp : OneFlow_IdempotentBaseOp<"rint", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_RoundOp : OneFlow_IdempotentBaseOp<"round", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_SignOp : OneFlow_IdempotentBaseOp<"sign", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +#endif // GET_ONEFLOW_IDEMPOTENT_OP_DEFINITIONS + +// Group: IDENTITY +// amp_white_identity, identity, identity_buffer, tuple_identity +// Total: 4 + +#ifdef GET_ONEFLOW_IDENTITY_OP_DEFINITIONS + +def OneFlow_AmpWhiteIdentityOp : OneFlow_BaseOp<"amp_white_identity", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_IdentityOp : OneFlow_BaseOp<"identity", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_IdentityBufferOp : OneFlow_BaseOp<"identity_buffer", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$buffer_size + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TupleIdentityOp : OneFlow_BaseOp<"tuple_identity", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$in + ); + let output = (outs + Variadic:$out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_sbp_signature_infer_fn = 1; +} + +#endif // GET_ONEFLOW_IDENTITY_OP_DEFINITIONS + +// Group: IMAGE +// image_batch_align, image_decode, image_flip, image_random_crop, image_resize_keep_aspect_ratio, image_resize_to_fixed +// 
+
+// Group: IMAGE
+// image_batch_align, image_decode, image_flip, image_random_crop, image_resize_keep_aspect_ratio, image_resize_to_fixed
+// Total: 6
+
+#ifdef GET_ONEFLOW_IMAGE_OP_DEFINITIONS
+
+def OneFlow_ImageBatchAlignOp : OneFlow_BaseOp<"image_batch_align", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    ShapeAttr:$shape,
+    OneFlow_DataType:$data_type,
+    DefaultValuedAttr:$alignment,
+    DefaultValuedAttr:$dynamic_out
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_output_arg_modify_fn = 1;
+}
+
+def OneFlow_ImageDecodeOp : OneFlow_BaseOp<"image_decode", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$color_space,
+    OneFlow_DataType:$data_type
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ImageFlipOp : OneFlow_BaseOp<"image_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$flip_code
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ImageRandomCropOp : OneFlow_BaseOp<"image_random_crop", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$num_attempts,
+    DefaultValuedAttr:$seed,
+    DefaultValuedAttr:$has_seed,
+    F32ArrayAttr:$random_area,
+    F32ArrayAttr:$random_aspect_ratio
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ImageResizeKeepAspectRatioOp : OneFlow_BaseOp<"image_resize_keep_aspect_ratio", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$size,
+    AnyType:$scale
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$target_size,
+    DefaultValuedAttr:$min_size,
+    DefaultValuedAttr:$max_size,
+    DefaultValuedAttr:$resize_longer,
+    DefaultValuedAttr:$interpolation_type
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ImageResizeToFixedOp : OneFlow_BaseOp<"image_resize_to_fixed", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$scale
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$target_width,
+    DefaultValuedAttr:$target_height,
+    DefaultValuedAttr:$channels,
+    OneFlow_DataType:$data_type,
+    DefaultValuedAttr:$interpolation_type
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_IMAGE_OP_DEFINITIONS
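+
+// NOTE (illustrative): every op in the IMAGE group above carries the CpuOnly
+// trait; decode/resize-style preprocessing runs on the host, so the trait
+// presumably lets the compiler rule out device placement for them up front.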
+
+// Group: INDICES
+// arg_sort, argmax, argwhere, batch_gather, dim_gather, dim_scatter_add, dim_scatter_add_like, dim_scatter_add_scalar, dim_scatter_mul, dim_scatter_mul_scalar, dim_scatter_update, dim_scatter_update_scalar, gather, gather_nd, generate_random_batch_permutation_indices, image_target_resize, logical_slice, scatter_nd, scatter_nd_like, slice, slice_grad, tensor_scatter_nd_add, tensor_scatter_nd_update, unsorted_batch_segment_sum, unsorted_segment_sum, unsorted_segment_sum_like, where, where_scalar_x, where_scalar_xy, where_scalar_y
+// Total: 30
+
+#ifdef GET_ONEFLOW_INDICES_OP_DEFINITIONS
+
+def OneFlow_ArgSortOp : OneFlow_BaseOp<"arg_sort", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    StrAttr:$direction
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ArgmaxOp : OneFlow_BaseOp<"argmax", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ArgwhereOp : OneFlow_BaseOp<"argwhere", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input
+  );
+  let output = (outs
+    AnyType:$output,
+    AnyType:$output_size
+  );
+  let attrs = (ins
+    OneFlow_DataType:$dtype
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_BatchGatherOp : OneFlow_BaseOp<"batch_gather", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$indices
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimGatherOp : OneFlow_BaseOp<"dim_gather", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterAddOp : OneFlow_BaseOp<"dim_scatter_add", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index,
+    AnyType:$src
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterAddLikeOp : OneFlow_BaseOp<"dim_scatter_add_like", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$like,
+    AnyType:$index,
+    AnyType:$src
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterAddScalarOp : OneFlow_BaseOp<"dim_scatter_add_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$src_scalar,
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterMulOp : OneFlow_BaseOp<"dim_scatter_mul", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index,
+    AnyType:$src
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterMulScalarOp : OneFlow_BaseOp<"dim_scatter_mul_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$src_scalar,
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterUpdateOp : OneFlow_BaseOp<"dim_scatter_update", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index,
+    AnyType:$src
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DimScatterUpdateScalarOp : OneFlow_BaseOp<"dim_scatter_update_scalar", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$index
+  );
+  let output = (outs
+    AnyType:$output
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$src_scalar,
+    DefaultValuedAttr:$dim
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_GatherOp : OneFlow_BaseOp<"gather", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$indices
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_GatherNdOp : OneFlow_BaseOp<"gather_nd", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$params,
+    AnyType:$indices
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_GenerateRandomBatchPermutationIndicesOp : OneFlow_BaseOp<"generate_random_batch_permutation_indices", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$seed
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
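+
+// NOTE (illustrative): the dim_scatter_* family above follows one scheme: a
+// tensor variant taking a $src operand, a *_like variant taking a $like
+// tensor in place of $input, and a *_scalar variant that replaces $src with
+// a $src_scalar attribute, with add/mul/update naming the reduction applied
+// at colliding indices.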
OneFlow_BaseOp<"image_target_resize", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out, + AnyType:$size, + AnyType:$scale + ); + let attrs = (ins + DefaultValuedAttr:$target_size, + DefaultValuedAttr:$max_size + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogicalSliceOp : OneFlow_BaseOp<"logical_slice", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ScatterNdOp : OneFlow_BaseOp<"scatter_nd", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$indices, + AnyType:$updates + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_ScatterNdLikeOp : OneFlow_BaseOp<"scatter_nd_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$like, + AnyType:$indices, + AnyType:$updates + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SliceOp : OneFlow_BaseOp<"slice", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SliceGradOp : OneFlow_BaseOp<"slice_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$like + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + SI64ArrayAttr:$start, + SI64ArrayAttr:$stop, + SI64ArrayAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_TensorScatterNdAddOp : OneFlow_BaseOp<"tensor_scatter_nd_add", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$params, + AnyType:$updates, + AnyType:$indices + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_TensorScatterNdUpdateOp : OneFlow_BaseOp<"tensor_scatter_nd_update", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$params, + AnyType:$updates, + AnyType:$indices + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_UnsortedBatchSegmentSumOp : 
OneFlow_BaseOp<"unsorted_batch_segment_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$data, + AnyType:$segment_ids + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$num_segments + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_UnsortedSegmentSumOp : OneFlow_BaseOp<"unsorted_segment_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$data, + AnyType:$segment_ids + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$num_segments + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_UnsortedSegmentSumLikeOp : OneFlow_BaseOp<"unsorted_segment_sum_like", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$data, + AnyType:$segment_ids, + AnyType:$like + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereOp : OneFlow_BaseOp<"where", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$condition, + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereScalarXOp : OneFlow_BaseOp<"where_scalar_x", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$condition, + AnyType:$y + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereScalarXyOp : OneFlow_BaseOp<"where_scalar_xy", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$condition + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_x_int_operand, + DefaultValuedAttr:$has_x_float_operand, + DefaultValuedAttr:$has_y_int_operand, + DefaultValuedAttr:$has_y_float_operand, + DefaultValuedAttr:$x_int_operand, + DefaultValuedAttr:$x_float_operand, + DefaultValuedAttr:$y_int_operand, + DefaultValuedAttr:$y_float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_WhereScalarYOp : OneFlow_BaseOp<"where_scalar_y", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$condition, + AnyType:$x + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$has_int_operand, + DefaultValuedAttr:$has_float_operand, + DefaultValuedAttr:$int_operand, + DefaultValuedAttr:$float_operand + ); + let has_logical_tensor_desc_infer_fn = 1; + let 
has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +#endif // GET_ONEFLOW_INDICES_OP_DEFINITIONS + +// Group: INVOLUTION +// negative, reciprocal +// Total: 2 + +#ifdef GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS + +def OneFlow_NegativeOp : OneFlow_InvolutionBaseOp<"negative", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +def OneFlow_ReciprocalOp : OneFlow_InvolutionBaseOp<"reciprocal", [NoSideEffect, DeclareOpInterfaceMethods]> {} + +#endif // GET_ONEFLOW_INVOLUTION_OP_DEFINITIONS + +// Group: LOSS +// combined_margin_loss, combined_margin_loss_grad, ctc_loss, ctc_loss_grad, dynamic_loss_scale_schedule, kl_div_loss, kl_div_loss_grad, smooth_l1_loss, smooth_l1_loss_grad +// Total: 9 + +#ifdef GET_ONEFLOW_LOSS_OP_DEFINITIONS + +def OneFlow_CombinedMarginLossOp : OneFlow_BaseOp<"combined_margin_loss", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$label + ); + let output = (outs + AnyType:$y, + AnyType:$theta + ); + let attrs = (ins + DefaultValuedAttr:$m1, + DefaultValuedAttr:$m2, + DefaultValuedAttr:$m3, + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_CombinedMarginLossGradOp : OneFlow_BaseOp<"combined_margin_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$label, + AnyType:$theta + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$m1, + DefaultValuedAttr:$m2, + DefaultValuedAttr:$m3, + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CtcLossOp : OneFlow_BaseOp<"ctc_loss", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$log_probs, + AnyType:$targets, + AnyType:$input_lengths, + AnyType:$target_lengths + ); + let output = (outs + AnyType:$loss, + AnyType:$alpha + ); + let attrs = (ins + DefaultValuedAttr:$max_target_length, + DefaultValuedAttr:$blank, + DefaultValuedAttr:$zero_infinity + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CtcLossGradOp : OneFlow_BaseOp<"ctc_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$grad_out, + AnyType:$log_probs, + AnyType:$targets, + AnyType:$input_lengths, + AnyType:$target_lengths, + AnyType:$loss, + AnyType:$alpha + ); + let output = (outs + AnyType:$grad + ); + let attrs = (ins + DefaultValuedAttr:$max_target_length, + DefaultValuedAttr:$blank, + DefaultValuedAttr:$zero_infinity + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DynamicLossScaleScheduleOp : OneFlow_BaseOp<"dynamic_loss_scale_schedule", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$count_not_finite, + AnyType:$loss_scale, + AnyType:$good_step_counter + ); + let attrs = (ins + DefaultValuedAttr:$increment_period, + DefaultValuedAttr:$multiplier + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let 
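+
+// NOTE (illustrative): OneFlow_InvolutionBaseOp mirrors the idempotent base
+// class above, but for ops where f(f(x)) = x (negative, reciprocal), so two
+// back-to-back applications can presumably be cancelled by a
+// canonicalization pattern.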
+
+// Group: LOSS
+// combined_margin_loss, combined_margin_loss_grad, ctc_loss, ctc_loss_grad, dynamic_loss_scale_schedule, kl_div_loss, kl_div_loss_grad, smooth_l1_loss, smooth_l1_loss_grad
+// Total: 9
+
+#ifdef GET_ONEFLOW_LOSS_OP_DEFINITIONS
+
+def OneFlow_CombinedMarginLossOp : OneFlow_BaseOp<"combined_margin_loss", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$label
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$theta
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$m1,
+    DefaultValuedAttr:$m2,
+    DefaultValuedAttr:$m3,
+    DefaultValuedAttr:$depth
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_CombinedMarginLossGradOp : OneFlow_BaseOp<"combined_margin_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$label,
+    AnyType:$theta
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$m1,
+    DefaultValuedAttr:$m2,
+    DefaultValuedAttr:$m3,
+    DefaultValuedAttr:$depth
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_CtcLossOp : OneFlow_BaseOp<"ctc_loss", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$log_probs,
+    AnyType:$targets,
+    AnyType:$input_lengths,
+    AnyType:$target_lengths
+  );
+  let output = (outs
+    AnyType:$loss,
+    AnyType:$alpha
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$max_target_length,
+    DefaultValuedAttr:$blank,
+    DefaultValuedAttr:$zero_infinity
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_CtcLossGradOp : OneFlow_BaseOp<"ctc_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$grad_out,
+    AnyType:$log_probs,
+    AnyType:$targets,
+    AnyType:$input_lengths,
+    AnyType:$target_lengths,
+    AnyType:$loss,
+    AnyType:$alpha
+  );
+  let output = (outs
+    AnyType:$grad
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$max_target_length,
+    DefaultValuedAttr:$blank,
+    DefaultValuedAttr:$zero_infinity
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_DynamicLossScaleScheduleOp : OneFlow_BaseOp<"dynamic_loss_scale_schedule", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$count_not_finite,
+    AnyType:$loss_scale,
+    AnyType:$good_step_counter
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$increment_period,
+    DefaultValuedAttr:$multiplier
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_KlDivLossOp : OneFlow_BaseOp<"kl_div_loss", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$target
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$log_target
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_KlDivLossGradOp : OneFlow_BaseOp<"kl_div_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$target,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$log_target
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SmoothL1LossOp : OneFlow_BaseOp<"smooth_l1_loss", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$target
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$beta
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_SmoothL1LossGradOp : OneFlow_BaseOp<"smooth_l1_loss_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$target,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$beta
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_LOSS_OP_DEFINITIONS
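+
+// NOTE (illustrative): one irregular entry above, dynamic_loss_scale_schedule,
+// declares inputs and attrs but no output block and sets
+// has_input_arg_modify_fn; it reads as an op that updates its loss-scale
+// state buffers in place rather than producing a new tensor.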
+
+// Group: MATH
+// abs_grad, ceil_grad, erf, erf_grad, exp, exp_grad, expand_grad, expm1, expm1_grad, floor_grad, floordiv_x_grad, floordiv_y_grad, lgamma, lgamma_grad, log, log1p, log1p_grad, log2_grad, log_grad, log_sigmoid, log_sigmoid_grad, negative_grad, reciprocal_grad, reciprocal_no_nan, reciprocal_no_nan_grad, rint_grad, round_grad, rsqrt, rsqrt_grad, sigmoid_v2, sigmoid_v2_grad, sign_grad, softplus, softplus_grad, softsign_grad, sqrt, sqrt_grad, square, square_grad, xlogy_x_grad, xlogy_y_grad
+// Total: 41
+
+#ifdef GET_ONEFLOW_MATH_OP_DEFINITIONS
+
+def OneFlow_AbsGradOp : OneFlow_BaseOp<"abs_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_CeilGradOp : OneFlow_BaseOp<"ceil_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ErfOp : OneFlow_BaseOp<"erf", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ErfGradOp : OneFlow_BaseOp<"erf_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ExpOp : OneFlow_BaseOp<"exp", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ExpGradOp : OneFlow_BaseOp<"exp_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ExpandGradOp : OneFlow_BaseOp<"expand_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$logical_out_shape,
+    SI32ArrayAttr:$logical_expand_shape
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_Expm1Op : OneFlow_BaseOp<"expm1", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_Expm1GradOp : OneFlow_BaseOp<"expm1_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FloorGradOp : OneFlow_BaseOp<"floor_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FloordivXGradOp : OneFlow_BaseOp<"floordiv_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_FloordivYGradOp : OneFlow_BaseOp<"floordiv_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dy
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LgammaOp : OneFlow_BaseOp<"lgamma", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LgammaGradOp : OneFlow_BaseOp<"lgamma_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LogOp : OneFlow_BaseOp<"log", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_Log1pOp : OneFlow_BaseOp<"log1p", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_Log1pGradOp : OneFlow_BaseOp<"log1p_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_Log2GradOp : OneFlow_BaseOp<"log2_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LogGradOp : OneFlow_BaseOp<"log_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LogSigmoidOp : OneFlow_BaseOp<"log_sigmoid", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LogSigmoidGradOp : OneFlow_BaseOp<"log_sigmoid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_NegativeGradOp : OneFlow_BaseOp<"negative_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReciprocalGradOp : OneFlow_BaseOp<"reciprocal_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReciprocalNoNanOp : OneFlow_BaseOp<"reciprocal_no_nan", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReciprocalNoNanGradOp : OneFlow_BaseOp<"reciprocal_no_nan_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_RintGradOp : OneFlow_BaseOp<"rint_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_RoundGradOp : OneFlow_BaseOp<"round_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_RsqrtOp : OneFlow_BaseOp<"rsqrt", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_RsqrtGradOp : OneFlow_BaseOp<"rsqrt_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SigmoidV2Op : OneFlow_BaseOp<"sigmoid_v2", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SigmoidV2GradOp : OneFlow_BaseOp<"sigmoid_v2_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SignGradOp : OneFlow_BaseOp<"sign_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SoftplusOp : OneFlow_BaseOp<"softplus", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SoftplusGradOp : OneFlow_BaseOp<"softplus_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SoftsignGradOp : OneFlow_BaseOp<"softsign_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SqrtOp : OneFlow_BaseOp<"sqrt", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SqrtGradOp : OneFlow_BaseOp<"sqrt_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SquareOp : OneFlow_BaseOp<"square", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SquareGradOp : OneFlow_BaseOp<"square_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_XlogyXGradOp : OneFlow_BaseOp<"xlogy_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_XlogyYGradOp : OneFlow_BaseOp<"xlogy_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dy
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_MATH_OP_DEFINITIONS
+
+// Group: MATMUL
+// batch_matmul, broadcast_matmul, broadcast_matmul_grad_b, distributed_partial_fc_sample, distributed_partial_fc_sample_disable_boxing, erfc, erfc_grad, matmul
+// Total: 8
+
+#ifdef GET_ONEFLOW_MATMUL_OP_DEFINITIONS
+
+def OneFlow_BatchMatmulOp : OneFlow_BaseOp<"batch_matmul", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b,
+    Optional:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$transpose_a,
+    DefaultValuedAttr:$transpose_b,
+    DefaultValuedAttr:$alpha
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_BroadcastMatmulOp : OneFlow_BaseOp<"broadcast_matmul", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b,
+    Optional:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$transpose_a,
+    DefaultValuedAttr:$transpose_b,
+    DefaultValuedAttr:$alpha
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_BroadcastMatmulGradBOp : OneFlow_BaseOp<"broadcast_matmul_grad_b", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b,
+    Optional:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$alpha
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_DistributedPartialFcSampleOp : OneFlow_BaseOp<"distributed_partial_fc_sample", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$weight,
+    AnyType:$label
+  );
+  let output = (outs
+    AnyType:$mapped_label,
+    AnyType:$sampled_label,
+    AnyType:$sampled_weight
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$num_sample,
+    DefaultValuedAttr:$seed
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_DistributedPartialFcSampleDisableBoxingOp : OneFlow_BaseOp<"distributed_partial_fc_sample_disable_boxing", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$sampled_weight_diff,
+    AnyType:$sampled_label
+  );
+  let output = (outs
+    AnyType:$boxing_disabled_sampled_weight_diff,
+    AnyType:$boxing_disabled_sampled_label
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ErfcOp : OneFlow_BaseOp<"erfc", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ErfcGradOp : OneFlow_BaseOp<"erfc_grad", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_MatmulOp : OneFlow_BaseOp<"matmul", [NoSideEffect, DeclareOpInterfaceMethods]> {
+  let input = (ins
+    AnyType:$a,
+    AnyType:$b,
+    Optional:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$transpose_a,
+    DefaultValuedAttr:$transpose_b,
+    DefaultValuedAttr:$alpha
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_MATMUL_OP_DEFINITIONS
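+
+// NOTE (illustrative): the matmul family above exposes BLAS-style knobs,
+// transpose_a/transpose_b plus an alpha scaling attribute, and an optional
+// _add_to_output operand that presumably lets the kernel fuse an
+// accumulation (out += alpha * a @ b) instead of writing a fresh buffer.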
+ AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$hash_precomputed + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_AddNOp : OneFlow_BaseOp<"add_n", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$in + ); + let output = (outs + AnyType:$out + ); + let hasCanonicalizer = 1; + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ArangeOp : OneFlow_BaseOp<"arange", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$integer_start, + DefaultValuedAttr:$integer_delta, + DefaultValuedAttr:$integer_limit, + DefaultValuedAttr:$float_start, + DefaultValuedAttr:$float_delta, + DefaultValuedAttr:$float_limit, + OneFlow_DataType:$dtype, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_CoinFlipOp : OneFlow_BaseOp<"coin_flip", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$probability, + DefaultValuedAttr:$batch_size, + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + StrArrayAttr:$nd_sbp + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_ConcatOp : OneFlow_BaseOp<"concat", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + Variadic:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis, + DefaultValuedAttr:$max_dim_size + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ConstantOp : OneFlow_BaseOp<"constant", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$floating_value, + DefaultValuedAttr:$integer_value, + DefaultValuedAttr:$is_floating_value, + OneFlow_DataType:$dtype, + ShapeAttr:$shape, + StrArrayAttr:$nd_sbp + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_DropoutOp : OneFlow_BaseOp<"dropout", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + Optional:$_add_to_output + ); + let output = (outs + AnyType:$out, + AnyType:$mask + ); + let attrs = (ins + DefaultValuedAttr:$rate + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ElementwiseMaximumBackwardOp : OneFlow_BaseOp<"elementwise_maximum_backward", [NoSideEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dz, + AnyType:$x, + AnyType:$y + ); + let output = (outs + Optional:$dx, + Optional:$dy + ); + let trait_attrs = 
(ins
+    I32ElementsAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ElementwiseMinimumBackwardOp : OneFlow_BaseOp<"elementwise_minimum_backward", [NoSideEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dz,
+    AnyType:$x,
+    AnyType:$y
+  );
+  let output = (outs
+    Optional<AnyType>:$dx,
+    Optional<AnyType>:$dy
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_EmptyOp : OneFlow_BaseOp<"empty", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    OneFlow_DataType:$dtype,
+    ShapeAttr:$shape,
+    StrArrayAttr:$nd_sbp
+  );
+  let same_output_regst_num = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_EyeOp : OneFlow_BaseOp<"eye", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$rows,
+    DefaultValuedAttr:$cols,
+    OneFlow_DataType:$dtype,
+    StrArrayAttr:$nd_sbp
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_GridSampleGradOp : OneFlow_BaseOp<"grid_sample_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$doutput,
+    AnyType:$input,
+    AnyType:$grid
+  );
+  let output = (outs
+    AnyType:$dinput,
+    AnyType:$dgrid
+  );
+  let attrs = (ins
+    StrAttr:$interpolation_mode,
+    StrAttr:$padding_mode,
+    DefaultValuedAttr:$align_corners
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_MultiCountNotFiniteOp : OneFlow_BaseOp<"multi_count_not_finite", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    Variadic<AnyType>:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_MultiSquareSumOp : OneFlow_BaseOp<"multi_square_sum", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    Variadic<AnyType>:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_NllOp : OneFlow_BaseOp<"nll", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$target,
+    Optional<AnyType>:$weight
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$total_weight
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$ignore_index
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_NllGradOp : OneFlow_BaseOp<"nll_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input,
+    AnyType:$target,
+    AnyType:$total_weight,
+    Optional<AnyType>:$weight,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$ignore_index
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_PowXGradOp : OneFlow_BaseOp<"pow_x_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_PowYGradOp : OneFlow_BaseOp<"pow_y_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dy
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_PreluGradOp : OneFlow_BaseOp<"prelu_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$x,
+    AnyType:$alpha
+  );
+  let output = (outs
+    AnyType:$dx,
+    AnyType:$alpha_diff
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_RandpermOp : OneFlow_BaseOp<"randperm", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$n,
+    DefaultValuedAttr:$seed,
+    StrArrayAttr:$nd_sbp
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_RecvOp : OneFlow_BaseOp<"recv", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$src_process_id,
+    OneFlow_DataType:$dtype,
+    ShapeAttr:$shape,
+    StrAttr:$device_type,
+    DefaultValuedAttr:$device_id
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+}
+
+def OneFlow_SendOp : OneFlow_BaseOp<"send", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$dst_process_id
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+}
+
+def OneFlow_SplitLikeOp : OneFlow_BaseOp<"split_like", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    Variadic<AnyType>:$like
+  );
+  let output = (outs
+    Variadic<AnyType>:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_SspVariableProxyOp : OneFlow_BaseOp<"ssp_variable_proxy", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$var
+  );
+  let output = (outs
+    AnyType:$ref,
+    AnyType:$value
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$buffer_size
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_output_arg_modify_fn = 1;
+}
+
+def OneFlow_TfPreluGradOp : OneFlow_BaseOp<"tf_prelu_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$x,
+    AnyType:$alpha
+  );
+  let output = (outs
+    AnyType:$dx,
+    AnyType:$alpha_diff
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_UniformOp : OneFlow_BaseOp<"uniform", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$from,
+    DefaultValuedAttr:$to,
+    DefaultValuedAttr:$seed,
+    OneFlow_DataType:$dtype,
+    ShapeAttr:$shape,
+    StrArrayAttr:$nd_sbp
+  );
+  let same_output_regst_num = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_UniformIntOp : OneFlow_BaseOp<"uniform_int", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$from,
+    DefaultValuedAttr:$to,
+    DefaultValuedAttr:$seed,
+    OneFlow_DataType:$dtype,
+    ShapeAttr:$shape,
+    StrArrayAttr:$nd_sbp
+  );
+  let same_output_regst_num = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_UniqueWithCountsOp : OneFlow_BaseOp<"unique_with_counts", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$idx,
+    AnyType:$count,
+    AnyType:$num_unique
+  );
+  let attrs = (ins
+    OneFlow_DataType:$out_idx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_XdivyXGradOp : OneFlow_BaseOp<"xdivy_x_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_XdivyYGradOp : OneFlow_BaseOp<"xdivy_y_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$y,
+    AnyType:$dz
+  );
+  let output = (outs
+    AnyType:$dy
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_MISC_OP_DEFINITIONS
+
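// The user-op definitions in this file all follow the pattern visible above:
// operands and results are declared with `let input` / `let output`, user-op
// attributes with `let attrs`, and the `has_*_fn = 1` bits record which
// inference hooks the generated op schema declares. A minimal sketch of the
// pattern; the op name `example_identity` is hypothetical, not part of this
// patch:

def OneFlow_ExampleIdentityOp : OneFlow_BaseOp<"example_identity", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$in);
  let output = (outs AnyType:$out);
  let has_logical_tensor_desc_infer_fn = 1;   // logical (global) shape inference
  let has_physical_tensor_desc_infer_fn = 1;  // per-device shape inference
  let has_get_sbp_fn = 1;                     // SBP signature enumeration
  let has_data_type_infer_fn = 1;             // output dtype inference
}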
+// Group: NCCL
+// _nccl_logical_2D_same_dim0_all2all, _nccl_logical_2D_same_dim0_all_gather, _nccl_logical_2D_same_dim0_all_gather_noncontinuous, _nccl_logical_2D_same_dim0_all_reduce, _nccl_logical_2D_same_dim1_all_reduce, _nccl_logical_all_gather, _nccl_logical_all_gather_noncontinuous, _nccl_logical_all_reduce, _nccl_logical_reduce_scatter, _nccl_logical_s2s
+// Total: 10
+
+#ifdef GET_ONEFLOW_NCCL_OP_DEFINITIONS
+
+def OneFlow__ncclLogical_2DSameDim0All2allOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all2all", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$in_dim1_split_axis,
+    DefaultValuedAttr:$out_dim1_split_axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogical_2DSameDim0AllGatherOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_gather", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogical_2DSameDim0AllGatherNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_gather_noncontinuous", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$in_dim1_split_axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogical_2DSameDim0AllReduceOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim0_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogical_2DSameDim1AllReduceOp : OneFlow_BaseOp<"_nccl_logical_2D_same_dim1_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogicalAllGatherOp : OneFlow_BaseOp<"_nccl_logical_all_gather", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogicalAllGatherNoncontinuousOp : OneFlow_BaseOp<"_nccl_logical_all_gather_noncontinuous", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$in_split_axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogicalAllReduceOp : OneFlow_BaseOp<"_nccl_logical_all_reduce", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogicalReduceScatterOp : OneFlow_BaseOp<"_nccl_logical_reduce_scatter", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow__ncclLogicalS2sOp : OneFlow_BaseOp<"_nccl_logical_s2s", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$in_split_axis,
+    DefaultValuedAttr:$out_split_axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_device_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_NCCL_OP_DEFINITIONS
+
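// All ten NCCL collectives above share one skeleton: a single $in/$out pair,
// optional split-axis attributes, and device/nd_sbp inference hooks in place
// of a physical tensor-desc hook (presumably because the per-device shape
// follows from the inferred nd_sbp). A hypothetical variant with one split
// axis; the name and the concrete attribute type are illustrative only:

def OneFlow__ncclLogicalExampleOp : OneFlow_BaseOp<"_nccl_logical_example", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$in);
  let output = (outs AnyType:$out);
  let attrs = (ins
    DefaultValuedAttr<SI64Attr, "-1">:$in_split_axis  // element type and default assumed
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
  let has_device_infer_fn = 1;
  let has_nd_sbp_infer_fn = 1;
}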
+// Group: NORMALIZATION
+// crop_mirror_normalize_from_tensorbuffer, crop_mirror_normalize_from_uint8, image_normalize, l2_normalize, l2_normalize_grad, layer_norm, layer_norm_grad, layer_norm_param_grad, normal, normalization, normalization_grad
+// Total: 11
+
+#ifdef GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS
+
+def OneFlow_CropMirrorNormalizeFromTensorbufferOp : OneFlow_BaseOp<"crop_mirror_normalize_from_tensorbuffer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    Optional<AnyType>:$mirror
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$color_space,
+    DefaultValuedAttr:$output_layout,
+    F32ArrayAttr:$mean,
+    F32ArrayAttr:$std,
+    DefaultValuedAttr:$crop_h,
+    DefaultValuedAttr:$crop_w,
+    DefaultValuedAttr:$crop_pos_x,
+    DefaultValuedAttr:$crop_pos_y,
+    OneFlow_DataType:$output_dtype
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_CropMirrorNormalizeFromUint8Op : OneFlow_BaseOp<"crop_mirror_normalize_from_uint8", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    Optional<AnyType>:$mirror
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$color_space,
+    DefaultValuedAttr:$output_layout,
+    F32ArrayAttr:$mean,
+    F32ArrayAttr:$std,
+    DefaultValuedAttr:$crop_h,
+    DefaultValuedAttr:$crop_w,
+    DefaultValuedAttr:$crop_pos_x,
+    DefaultValuedAttr:$crop_pos_y,
+    OneFlow_DataType:$output_dtype
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ImageNormalizeOp : OneFlow_BaseOp<"image_normalize", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    F32ArrayAttr:$std,
+    F32ArrayAttr:$mean
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_L2NormalizeOp : OneFlow_BaseOp<"l2_normalize", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$square_x_sum
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$epsilon
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_L2NormalizeGradOp : OneFlow_BaseOp<"l2_normalize_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$y,
+    AnyType:$square_x_sum
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$epsilon
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LayerNormOp : OneFlow_BaseOp<"layer_norm", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    Optional<AnyType>:$beta,
+    Optional<AnyType>:$gamma
+  );
+  let output = (outs
+    AnyType:$y,
+    AnyType:$mean,
+    AnyType:$inv_variance,
+    Optional<AnyType>:$normalized
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$center,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$begin_norm_axis,
+    DefaultValuedAttr:$begin_params_axis,
+    DefaultValuedAttr:$epsilon
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$x,
+    AnyType:$mean,
+    AnyType:$inv_variance,
+    Optional<AnyType>:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$begin_norm_axis,
+    DefaultValuedAttr:$epsilon
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    Optional<AnyType>:$normalized,
+    Optional<AnyType>:$gamma
+  );
+  let output = (outs
+    Optional<AnyType>:$normalized_diff,
+    Optional<AnyType>:$beta_diff,
+    Optional<AnyType>:$gamma_diff,
+    Optional<AnyType>:$reduce_buf
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$begin_params_axis
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes,
+    I32ElementsAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_NormalOp : OneFlow_BaseOp<"normal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$mean,
+    DefaultValuedAttr:$std,
+    DefaultValuedAttr:$seed,
+    OneFlow_DataType:$dtype,
+    ShapeAttr:$shape,
+    StrArrayAttr:$nd_sbp
+  );
+  let same_output_regst_num = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_NormalizationOp : OneFlow_BaseOp<"normalization", [NoSideEffect, AttrSizedOperandSegments, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    Optional<AnyType>:$moving_mean,
+    Optional<AnyType>:$moving_variance,
+    AnyType:$gamma,
+    AnyType:$beta,
+    Optional<AnyType>:$_add_to_output
+  );
+  let output = (outs
+    AnyType:$y,
+    Optional<AnyType>:$mean,
+    Optional<AnyType>:$inv_variance
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$training,
+    DefaultValuedAttr:$momentum
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes,
+    I32ElementsAttr:$result_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_NormalizationGradOp : OneFlow_BaseOp<"normalization_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy,
+    AnyType:$mean,
+    AnyType:$inv_variance,
+    AnyType:$gamma
+  );
+  let output = (outs
+    AnyType:$gamma_diff,
+    AnyType:$beta_diff,
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$axis,
+    DefaultValuedAttr:$epsilon
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_NORMALIZATION_OP_DEFINITIONS
+
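// `AttrSizedOperandSegments` (and its result-side twin) is required whenever
// an op mixes required and `Optional`/`Variadic` values, as layer_norm and
// normalization do above: MLIR then needs a dense i32 array recording how many
// SSA values each declared operand actually received at a call site, which is
// what the `operand_segment_sizes` entry in `trait_attrs` carries. Sketch with
// a hypothetical op name:

def OneFlow_ExampleMixedOp : OneFlow_BaseOp<"example_mixed", [NoSideEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    AnyType:$x,               // always contributes exactly 1 value
    Optional<AnyType>:$beta,  // contributes 0 or 1 values
    Optional<AnyType>:$gamma  // contributes 0 or 1 values
  );
  let output = (outs AnyType:$y);
  let trait_attrs = (ins
    I32ElementsAttr:$operand_segment_sizes  // e.g. [1, 1, 0] when $gamma is absent
  );
}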
+// Group: OPTIMIZER
+// adagrad_update, adam_bias_correction_factor, adam_update, indexed_slices_adam_update, indexed_slices_momentum_update, indexed_slices_sgd_update, lamb_update, lars_update, momentum_update, rmsprop_update, sgd_update, slice_update
+// Total: 12
+
+#ifdef GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS
+
+def OneFlow_AdagradUpdateOp : OneFlow_BaseOp<"adagrad_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff,
+    Optional<AnyType>:$learning_rate,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if,
+    Optional<AnyType>:$train_step,
+    AnyType:$sum
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$train_step_val,
+    DefaultValuedAttr:$learning_rate_val,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$lr_decay,
+    DefaultValuedAttr:$weight_decay,
+    DefaultValuedAttr:$epsilon
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_AdamBiasCorrectionFactorOp : OneFlow_BaseOp<"adam_bias_correction_factor", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$train_step
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$beta
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_AdamUpdateOp : OneFlow_BaseOp<"adam_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff,
+    Optional<AnyType>:$learning_rate,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if,
+    Optional<AnyType>:$bias_correction1,
+    Optional<AnyType>:$bias_correction2,
+    AnyType:$m,
+    AnyType:$v,
+    AnyType:$max_v
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$learning_rate_val,
+    DefaultValuedAttr:$bias_correction1_val,
+    DefaultValuedAttr:$bias_correction2_val,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$beta1,
+    DefaultValuedAttr:$beta2,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$weight_decay,
+    DefaultValuedAttr:$amsgrad,
+    DefaultValuedAttr:$do_bias_correction
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_IndexedSlicesAdamUpdateOp : OneFlow_BaseOp<"indexed_slices_adam_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff_indices,
+    AnyType:$model_diff_values,
+    AnyType:$learning_rate,
+    Optional<AnyType>:$bias_correction1,
+    Optional<AnyType>:$bias_correction2,
+    AnyType:$m,
+    AnyType:$v,
+    AnyType:$max_v
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$learning_rate_val,
+    DefaultValuedAttr:$beta1,
+    DefaultValuedAttr:$beta2,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$weight_decay,
+    DefaultValuedAttr:$amsgrad,
+    DefaultValuedAttr:$do_bias_correction
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_IndexedSlicesMomentumUpdateOp : OneFlow_BaseOp<"indexed_slices_momentum_update", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff_indices,
+    AnyType:$model_diff_values,
+    AnyType:$learning_rate,
+    AnyType:$momentum
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$beta,
+    DefaultValuedAttr:$weight_decay
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_IndexedSlicesSgdUpdateOp : OneFlow_BaseOp<"indexed_slices_sgd_update", [NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff_indices,
+    AnyType:$model_diff_values,
+    AnyType:$learning_rate
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$weight_decay
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_LambUpdateOp : OneFlow_BaseOp<"lamb_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$m,
+    AnyType:$v,
+    AnyType:$beta1_t,
+    AnyType:$beta2_t,
+    AnyType:$model,
+    AnyType:$model_diff,
+    AnyType:$learning_rate,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$beta1,
+    DefaultValuedAttr:$beta2,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$weight_decay
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_LarsUpdateOp : OneFlow_BaseOp<"lars_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff,
+    AnyType:$learning_rate,
+    AnyType:$momentum,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$momentum_beta,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$lars_coefficient,
+    DefaultValuedAttr:$weight_decay
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_MomentumUpdateOp : OneFlow_BaseOp<"momentum_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff,
+    AnyType:$momentum,
+    Optional<AnyType>:$learning_rate,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$learning_rate_val,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$beta,
+    DefaultValuedAttr:$weight_decay
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_RmspropUpdateOp : OneFlow_BaseOp<"rmsprop_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff,
+    Optional<AnyType>:$learning_rate,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if,
+    AnyType:$mean_square,
+    Optional<AnyType>:$mean_gradient
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$learning_rate_val,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$centered,
+    DefaultValuedAttr:$epsilon,
+    DefaultValuedAttr:$decay_rate,
+    DefaultValuedAttr:$weight_decay
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_SgdUpdateOp : OneFlow_BaseOp<"sgd_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$model,
+    AnyType:$model_diff,
+    Optional<AnyType>:$learning_rate,
+    Optional<AnyType>:$scale_by_tensor,
+    Optional<AnyType>:$skip_if
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$learning_rate_val,
+    DefaultValuedAttr:$scale,
+    DefaultValuedAttr:$l1,
+    DefaultValuedAttr:$l2,
+    DefaultValuedAttr:$weight_decay
+  );
+  let trait_attrs = (ins
+    I32ElementsAttr:$operand_segment_sizes
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_SliceUpdateOp : OneFlow_BaseOp<"slice_update", [DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$update
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$start,
+    SI64ArrayAttr:$stop,
+    SI64ArrayAttr:$step
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_OPTIMIZER_OP_DEFINITIONS
+
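// The optimizer update ops are declared [NoGrad] but, unlike most of this
// file, not [NoSideEffect], and they declare no outputs: they presumably
// update $model and the optimizer state tensors in place, which is also why
// they set has_input_arg_modify_fn. A trimmed-down, hypothetical update op
// (name and comments illustrative only):

def OneFlow_ExampleUpdateOp : OneFlow_BaseOp<"example_update", [NoGrad, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins
    AnyType:$model,
    AnyType:$model_diff,
    Optional<AnyType>:$learning_rate,  // scheduled LR tensor; else learning_rate_val is used (semantics assumed)
    Optional<AnyType>:$skip_if         // skip flag (semantics assumed)
  );
  // no `output`: the op mutates its inputs
  let attrs = (ins
    DefaultValuedAttr:$learning_rate_val  // attr type/default elided, as in the definitions above
  );
  let trait_attrs = (ins
    I32ElementsAttr:$operand_segment_sizes
  );
  let has_input_arg_modify_fn = 1;
}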
+// Group: PADDING
+// constant_pad1d, constant_pad1d_grad, constant_pad2d, constant_pad2d_grad, constant_pad3d, constant_pad3d_grad, pad, pad_grad, reflection_pad2d, reflection_pad2d_grad, replication_pad2d, replication_pad2d_grad, same_padding, same_padding_grad
+// Total: 14
+
+#ifdef GET_ONEFLOW_PADDING_OP_DEFINITIONS
+
+def OneFlow_ConstantPad1DOp : OneFlow_BaseOp<"constant_pad1d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding,
+    DefaultValuedAttr:$floating_value,
+    DefaultValuedAttr:$integral_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ConstantPad1DGradOp : OneFlow_BaseOp<"constant_pad1d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding,
+    DefaultValuedAttr:$floating_value,
+    DefaultValuedAttr:$integral_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ConstantPad2DOp : OneFlow_BaseOp<"constant_pad2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding,
+    DefaultValuedAttr:$floating_value,
+    DefaultValuedAttr:$integral_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ConstantPad2DGradOp : OneFlow_BaseOp<"constant_pad2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding,
+    DefaultValuedAttr:$floating_value,
+    DefaultValuedAttr:$integral_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ConstantPad3DOp : OneFlow_BaseOp<"constant_pad3d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding,
+    DefaultValuedAttr:$floating_value,
+    DefaultValuedAttr:$integral_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ConstantPad3DGradOp : OneFlow_BaseOp<"constant_pad3d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding,
+    DefaultValuedAttr:$floating_value,
+    DefaultValuedAttr:$integral_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_PadOp : OneFlow_BaseOp<"pad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding_before,
+    SI64ArrayAttr:$padding_after,
+    DefaultValuedAttr:$floating_constant_value,
+    DefaultValuedAttr:$integral_constant_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_PadGradOp : OneFlow_BaseOp<"pad_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding_before,
+    SI64ArrayAttr:$padding_after,
+    DefaultValuedAttr:$floating_constant_value,
+    DefaultValuedAttr:$integral_constant_value
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReflectionPad2DOp : OneFlow_BaseOp<"reflection_pad2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ReflectionPad2DGradOp : OneFlow_BaseOp<"reflection_pad2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReplicationPad2DOp : OneFlow_BaseOp<"replication_pad2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ReplicationPad2DGradOp : OneFlow_BaseOp<"replication_pad2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    SI64ArrayAttr:$padding
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SamePaddingOp : OneFlow_BaseOp<"same_padding", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    StrAttr:$padding,
+    StrAttr:$data_format,
+    SI32ArrayAttr:$kernel_size,
+    SI32ArrayAttr:$strides,
+    SI32ArrayAttr:$dilation_rate
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_SamePaddingGradOp : OneFlow_BaseOp<"same_padding_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x_like,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    StrAttr:$padding,
+    StrAttr:$data_format,
+    SI32ArrayAttr:$kernel_size,
+    SI32ArrayAttr:$strides,
+    SI32ArrayAttr:$dilation_rate
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_PADDING_OP_DEFINITIONS
+
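// Every padding op above is mirrored by a `*_grad` op that consumes $dy and
// repeats the forward attributes verbatim, so the backward kernel can recover
// the padding geometry without extra bookkeeping. The pair pattern, with a
// hypothetical mnemonic:

def OneFlow_ExamplePadOp : OneFlow_BaseOp<"example_pad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$x);
  let output = (outs AnyType:$y);
  let attrs = (ins SI64ArrayAttr:$padding);
}

def OneFlow_ExamplePadGradOp : OneFlow_BaseOp<"example_pad_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$dy);
  let output = (outs AnyType:$dx);
  let attrs = (ins SI64ArrayAttr:$padding);  // same attributes as the forward op
}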
+// Group: PARALLEL_CAST
+// hierarchical_parallel_cast, hierarchical_parallel_cast_like, parallel_cast
+// Total: 3
+
+#ifdef GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS
+
+def OneFlow_HierarchicalParallelCastOp : OneFlow_BaseOp<"hierarchical_parallel_cast", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    StrArrayAttr:$nd_sbp,
+    StrAttr:$grad_mode,
+    StrArrayAttr:$grad_nd_sbp
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_HierarchicalParallelCastLikeOp : OneFlow_BaseOp<"hierarchical_parallel_cast_like", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$like
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_ParallelCastOp : OneFlow_BaseOp<"parallel_cast", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    StrAttr:$sbp_parallel,
+    StrAttr:$grad_sbp_parallel
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_sbp_signature_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_PARALLEL_CAST_OP_DEFINITIONS
+
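// Note that parallel_cast is the only op in this section that sets
// has_sbp_signature_infer_fn in addition to has_get_sbp_fn; the SBP it casts
// to travels as plain string attributes. Minimal sketch (hypothetical name;
// the "S(0)" string format is an assumption):

def OneFlow_ExampleSbpCastOp : OneFlow_BaseOp<"example_sbp_cast", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$in);
  let output = (outs AnyType:$out);
  let attrs = (ins
    StrAttr:$sbp_parallel,      // e.g. a textual SBP such as "S(0)"
    StrAttr:$grad_sbp_parallel
  );
  let has_get_sbp_fn = 1;
  let has_sbp_signature_infer_fn = 1;  // whole-signature hook, on top of the per-op SBP hook
}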
+// Group: POOL
+// adaptive_avg_pool1d, adaptive_avg_pool1d_grad, adaptive_avg_pool2d, adaptive_avg_pool2d_grad, adaptive_avg_pool3d, adaptive_avg_pool3d_grad, avgpool_1d, avgpool_1d_grad, avgpool_2d, avgpool_2d_grad, avgpool_3d, avgpool_3d_grad, maxpool_1d, maxpool_1d_grad, maxpool_2d, maxpool_2d_grad, maxpool_3d, maxpool_3d_grad, tf_avg_pool_1d, tf_avg_pool_1d_grad, tf_avg_pool_2d, tf_avg_pool_2d_grad, tf_avg_pool_3d, tf_avg_pool_3d_grad, tf_max_pool_1d, tf_max_pool_1d_grad, tf_max_pool_2d, tf_max_pool_2d_grad, tf_max_pool_3d, tf_max_pool_3d_grad
+// Total: 30
+
+#ifdef GET_ONEFLOW_POOL_OP_DEFINITIONS
+
+def OneFlow_AdaptiveAvgPool1DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool1d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AdaptiveAvgPool1DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool1d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AdaptiveAvgPool2DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AdaptiveAvgPool2DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AdaptiveAvgPool3DOp : OneFlow_AdaptivePoolBaseOp<"adaptive_avg_pool3d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AdaptiveAvgPool3DGradOp : OneFlow_AdaptivePoolGradBaseOp<"adaptive_avg_pool3d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AvgPool1DOp : OneFlow_AvgPoolBaseOp<"avgpool_1d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AvgPool1DGradOp : OneFlow_AvgPoolGradBaseOp<"avgpool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AvgPool2DOp : OneFlow_AvgPoolBaseOp<"avgpool_2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AvgPool2DGradOp : OneFlow_AvgPoolGradBaseOp<"avgpool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AvgPool3DOp : OneFlow_AvgPoolBaseOp<"avgpool_3d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_AvgPool3DGradOp : OneFlow_AvgPoolGradBaseOp<"avgpool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_MaxPool1DOp : OneFlow_MaxPoolBaseOp<"maxpool_1d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_MaxPool1DGradOp : OneFlow_MaxPoolGradBaseOp<"maxpool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_MaxPool2DOp : OneFlow_MaxPoolBaseOp<"maxpool_2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_MaxPool2DGradOp : OneFlow_MaxPoolGradBaseOp<"maxpool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_MaxPool3DOp : OneFlow_MaxPoolBaseOp<"maxpool_3d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_MaxPool3DGradOp : OneFlow_MaxPoolGradBaseOp<"maxpool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfAvgPool1DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_1d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfAvgPool1DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfAvgPool2DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfAvgPool2DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfAvgPool3DOp : OneFlow_TFPoolBaseOp<"tf_avg_pool_3d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfAvgPool3DGradOp : OneFlow_TFPoolGradBaseOp<"tf_avg_pool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfMaxPool1DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_1d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfMaxPool1DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfMaxPool2DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfMaxPool2DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfMaxPool3DOp : OneFlow_TFPoolBaseOp<"tf_max_pool_3d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+def OneFlow_TfMaxPool3DGradOp : OneFlow_TFPoolGradBaseOp<"tf_max_pool_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}
+
+#endif // GET_ONEFLOW_POOL_OP_DEFINITIONS
+
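// The 30 pool ops are one-line subclasses: the shared operand/attribute
// signature lives in base classes (OneFlow_MaxPoolBaseOp and friends) defined
// elsewhere in this file, so each concrete op only picks a mnemonic and a
// trait list. A plausible shape for such a base class, sketched under the
// assumption of a single input/output and the usual four hooks:

class OneFlow_ExamplePoolBaseOp<string mnemonic, list<OpTrait> traits = []>
    : OneFlow_BaseOp<mnemonic, traits> {
  let input = (ins AnyType:$x);
  let output = (outs AnyType:$y);
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}

def OneFlow_ExamplePool2DOp : OneFlow_ExamplePoolBaseOp<"example_pool_2d", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {}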
+// Group: QUANTIZATION
+// fake_quantization, min_max_observer, moving_average_min_max_observer, quantization
+// Total: 4
+
+#ifdef GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS
+
+def OneFlow_FakeQuantizationOp : OneFlow_BaseOp<"fake_quantization", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$scale,
+    AnyType:$zero_point
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_MinMaxObserverOp : OneFlow_BaseOp<"min_max_observer", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$scale,
+    AnyType:$zero_point
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme,
+    DefaultValuedAttr:$per_layer_quantization
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_MovingAverageMinMaxObserverOp : OneFlow_BaseOp<"moving_average_min_max_observer", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$current_train_step,
+    AnyType:$moving_max,
+    AnyType:$moving_min
+  );
+  let output = (outs
+    AnyType:$scale,
+    AnyType:$zero_point
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$training,
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$stop_update_after_iters,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme,
+    DefaultValuedAttr:$momentum
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_QuantizationOp : OneFlow_BaseOp<"quantization", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$scale,
+    AnyType:$zero_point
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$quantization_formula,
+    DefaultValuedAttr:$quantization_bit,
+    DefaultValuedAttr:$quantization_scheme
+  );
+  let has_check_fn = 1;
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+#endif // GET_ONEFLOW_QUANTIZATION_OP_DEFINITIONS
+
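// All four quantization ops set has_check_fn = 1 on top of the usual hooks,
// i.e. the generated schema additionally declares an attribute-validation
// hook (plausibly range-checking quantization_bit and the formula/scheme
// strings; the concrete checks live in C++, not in this file). Sketch with a
// hypothetical mnemonic:

def OneFlow_ExampleQuantOp : OneFlow_BaseOp<"example_quant", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$in, AnyType:$scale, AnyType:$zero_point);
  let output = (outs AnyType:$out);
  let has_check_fn = 1;  // declares the extra validation hook
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}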
+// Group: REDUCE
+// indexed_slices_reduce_sum, reduce_all, reduce_any, reduce_max, reduce_max_device_stage, reduce_max_device_stage_grad, reduce_max_global_stage, reduce_max_global_stage_grad, reduce_min, reduce_min_device_stage, reduce_min_device_stage_grad, reduce_min_global_stage, reduce_min_global_stage_grad, reduce_prod, reduce_sum, reduce_sum_like
+// Total: 16
+
+#ifdef GET_ONEFLOW_REDUCE_OP_DEFINITIONS
+
+def OneFlow_IndexedSlicesReduceSumOp : OneFlow_BaseOp<"indexed_slices_reduce_sum", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x_indices,
+    AnyType:$x_values
+  );
+  let output = (outs
+    AnyType:$y_indices,
+    AnyType:$y_values,
+    AnyType:$num_unique
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceAllOp : OneFlow_BaseOp<"reduce_all", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input_tensor
+  );
+  let output = (outs
+    AnyType:$output_tensor
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceAnyOp : OneFlow_BaseOp<"reduce_any", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input_tensor
+  );
+  let output = (outs
+    AnyType:$output_tensor
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMaxOp : OneFlow_BaseOp<"reduce_max", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input_tensor
+  );
+  let output = (outs
+    AnyType:$output_tensor
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMaxDeviceStageOp : OneFlow_BaseOp<"reduce_max_device_stage", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$mask,
+    AnyType:$count
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMaxDeviceStageGradOp : OneFlow_BaseOp<"reduce_max_device_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$out_diff,
+    AnyType:$mask,
+    AnyType:$count
+  );
+  let output = (outs
+    AnyType:$in_diff
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMaxGlobalStageOp : OneFlow_BaseOp<"reduce_max_global_stage", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$device_count
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$mask
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ReduceMaxGlobalStageGradOp : OneFlow_BaseOp<"reduce_max_global_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$out_diff,
+    AnyType:$mask,
+    AnyType:$device_count
+  );
+  let output = (outs
+    AnyType:$in_diff
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMinOp : OneFlow_BaseOp<"reduce_min", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input_tensor
+  );
+  let output = (outs
+    AnyType:$output_tensor
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMinDeviceStageOp : OneFlow_BaseOp<"reduce_min_device_stage", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$mask,
+    AnyType:$count
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMinDeviceStageGradOp : OneFlow_BaseOp<"reduce_min_device_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$out_diff,
+    AnyType:$mask,
+    AnyType:$count
+  );
+  let output = (outs
+    AnyType:$in_diff
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceMinGlobalStageOp : OneFlow_BaseOp<"reduce_min_global_stage", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$device_count
+  );
+  let output = (outs
+    AnyType:$out,
+    AnyType:$mask
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+def OneFlow_ReduceMinGlobalStageGradOp : OneFlow_BaseOp<"reduce_min_global_stage_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$out_diff,
+    AnyType:$mask,
+    AnyType:$device_count
+  );
+  let output = (outs
+    AnyType:$in_diff
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceProdOp : OneFlow_BaseOp<"reduce_prod", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input_tensor
+  );
+  let output = (outs
+    AnyType:$output_tensor
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceSumOp : OneFlow_BaseOp<"reduce_sum", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$input_tensor
+  );
+  let output = (outs
+    AnyType:$output_tensor
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis,
+    DefaultValuedAttr:$keepdims
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ReduceSumLikeOp : OneFlow_BaseOp<"reduce_sum_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$like
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    SI32ArrayAttr:$axis
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+}
+
+#endif // GET_ONEFLOW_REDUCE_OP_DEFINITIONS
+
+// Group: RESHAPE
+// reshape, reshape_like
+// Total: 2
+
+#ifdef GET_ONEFLOW_RESHAPE_OP_DEFINITIONS
+
+def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    ShapeAttr:$shape
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+def OneFlow_ReshapeLikeOp : OneFlow_BaseOp<"reshape_like", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in,
+    AnyType:$like
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+  let has_input_arg_modify_fn = 1;
+  let has_nd_sbp_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_RESHAPE_OP_DEFINITIONS
+
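// The simple reductions above share an axis/keepdims signature, while
// reduce_min/max additionally come as device_stage/global_stage pairs whose
// device stage emits $mask and $count for the global stage and the grad ops
// to consume. The common signature, with a hypothetical mnemonic:

def OneFlow_ExampleReduceOp : OneFlow_BaseOp<"example_reduce", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
  let input = (ins AnyType:$input_tensor);
  let output = (outs AnyType:$output_tensor);
  let attrs = (ins
    SI32ArrayAttr:$axis,         // dimensions to reduce over
    DefaultValuedAttr:$keepdims  // attr type/default elided, as in the definitions above
  );
  let has_logical_tensor_desc_infer_fn = 1;
  let has_physical_tensor_desc_infer_fn = 1;
  let has_get_sbp_fn = 1;
  let has_data_type_infer_fn = 1;
}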
+// Group: SCALAR
+// clip_by_scalar, clip_by_scalar_grad, clip_by_scalar_max, clip_by_scalar_max_grad, clip_by_scalar_min, clip_by_scalar_min_grad, scalar_add, scalar_add_by_tensor, scalar_div_by_tensor, scalar_floordiv, scalar_fmod, scalar_logical_and, scalar_logical_equal, scalar_logical_greater, scalar_logical_greater_equal, scalar_logical_less, scalar_logical_less_equal, scalar_logical_not_equal, scalar_logical_or, scalar_logical_xor, scalar_mul, scalar_mul_by_tensor, scalar_pow, scalar_pow_grad, scalar_sub_by_tensor
+// Total: 25
+
+#ifdef GET_ONEFLOW_SCALAR_OP_DEFINITIONS
+
+def OneFlow_ClipByScalarOp : OneFlow_BaseOp<"clip_by_scalar", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$floating_min,
+    DefaultValuedAttr:$integral_min,
+    DefaultValuedAttr:$floating_max,
+    DefaultValuedAttr:$integral_max
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ClipByScalarGradOp : OneFlow_BaseOp<"clip_by_scalar_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$floating_min,
+    DefaultValuedAttr:$integral_min,
+    DefaultValuedAttr:$floating_max,
+    DefaultValuedAttr:$integral_max
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ClipByScalarMaxOp : OneFlow_BaseOp<"clip_by_scalar_max", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$floating_max,
+    DefaultValuedAttr:$integral_max
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ClipByScalarMaxGradOp : OneFlow_BaseOp<"clip_by_scalar_max_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$floating_max,
+    DefaultValuedAttr:$integral_max
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ClipByScalarMinOp : OneFlow_BaseOp<"clip_by_scalar_min", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$floating_min,
+    DefaultValuedAttr:$integral_min
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ClipByScalarMinGradOp : OneFlow_BaseOp<"clip_by_scalar_min_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$dy,
+    AnyType:$x
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$floating_min,
+    DefaultValuedAttr:$integral_min
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarAddOp : OneFlow_BaseOp<"scalar_add", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarAddByTensorOp : OneFlow_BaseOp<"scalar_add_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$scalar
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarDivByTensorOp : OneFlow_BaseOp<"scalar_div_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$scalar
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarFloordivOp : OneFlow_BaseOp<"scalar_floordiv", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarFmodOp : OneFlow_BaseOp<"scalar_fmod", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalAndOp : OneFlow_BaseOp<"scalar_logical_and", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalEqualOp : OneFlow_BaseOp<"scalar_logical_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalGreaterOp : OneFlow_BaseOp<"scalar_logical_greater", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalGreaterEqualOp : OneFlow_BaseOp<"scalar_logical_greater_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalLessOp : OneFlow_BaseOp<"scalar_logical_less", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalLessEqualOp : OneFlow_BaseOp<"scalar_logical_less_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalNotEqualOp : OneFlow_BaseOp<"scalar_logical_not_equal", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalOrOp : OneFlow_BaseOp<"scalar_logical_or", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarLogicalXorOp : OneFlow_BaseOp<"scalar_logical_xor", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarMulOp : OneFlow_BaseOp<"scalar_mul", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarMulByTensorOp : OneFlow_BaseOp<"scalar_mul_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$scalar
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarPowOp : OneFlow_BaseOp<"scalar_pow", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$in
+  );
+  let output = (outs
+    AnyType:$out
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarPowGradOp : OneFlow_BaseOp<"scalar_pow_grad", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$dy
+  );
+  let output = (outs
+    AnyType:$dx
+  );
+  let attrs = (ins
+    DefaultValuedAttr:$has_int_operand,
+    DefaultValuedAttr:$has_float_operand,
+    DefaultValuedAttr:$int_operand,
+    DefaultValuedAttr:$float_operand
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+def OneFlow_ScalarSubByTensorOp : OneFlow_BaseOp<"scalar_sub_by_tensor", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
+  let input = (ins
+    AnyType:$x,
+    AnyType:$scalar
+  );
+  let output = (outs
+    AnyType:$y
+  );
+  let has_logical_tensor_desc_infer_fn = 1;
+  let has_physical_tensor_desc_infer_fn = 1;
+  let has_get_sbp_fn = 1;
+  let has_data_type_infer_fn = 1;
+}
+
+#endif // GET_ONEFLOW_SCALAR_OP_DEFINITIONS
+
has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyGradOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$label, + AnyType:$dy, + AnyType:$prob + ); + let output = (outs + AnyType:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyMsOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_ms", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$prediction, + AnyType:$label + ); + let output = (outs + AnyType:$prob, + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_SparseSoftmaxCrossEntropyMsGradOp : OneFlow_BaseOp<"sparse_softmax_cross_entropy_ms_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$label, + AnyType:$dy, + AnyType:$prob + ); + let output = (outs + AnyType:$prediction_diff + ); + let attrs = (ins + DefaultValuedAttr:$depth + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_SOFTMAX_OP_DEFINITIONS + +// Group: SUMMARY +// create_summary_writer, flush_summary_writer, summary_write_histogram, summary_write_image, summary_write_pb, summary_write_scalar +// Total: 6 + +#ifdef GET_ONEFLOW_SUMMARY_OP_DEFINITIONS + +def OneFlow_CreateSummaryWriterOp : OneFlow_BaseOp<"create_summary_writer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let attrs = (ins + StrAttr:$logdir + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlushSummaryWriterOp : OneFlow_BaseOp<"flush_summary_writer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWriteHistogramOp : OneFlow_BaseOp<"summary_write_histogram", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$step, + AnyType:$tag + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWriteImageOp : OneFlow_BaseOp<"summary_write_image", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$step, + AnyType:$tag + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWritePbOp : OneFlow_BaseOp<"summary_write_pb", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SummaryWriteScalarOp : OneFlow_BaseOp<"summary_write_scalar", 
[NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in, + AnyType:$step, + AnyType:$tag + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_SUMMARY_OP_DEFINITIONS + +// Group: TENSOR_BUFFER +// gen_tensor_buffer, tensor_buffer_to_list_of_tensors, tensor_buffer_to_list_of_tensors_v2, tensor_buffer_to_tensor, tensor_to_tensor_buffer +// Total: 5 + +#ifdef GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS + +def OneFlow_GenTensorBufferOp : OneFlow_BaseOp<"gen_tensor_buffer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + ShapeAttr:$shape, + ShapeArrayAttr:$shape_list, + F32ArrayAttr:$value_list, + OneFlow_DataType:$data_type, + DefaultValuedAttr:$dynamic_out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TensorBufferToListOfTensorsOp : OneFlow_BaseOp<"tensor_buffer_to_list_of_tensors", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + Variadic:$out + ); + let attrs = (ins + ShapeAttr:$out_shape, + OneFlow_DataType:$out_dtype, + DefaultValuedAttr:$dynamic_out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TensorBufferToListOfTensorsV2Op : OneFlow_BaseOp<"tensor_buffer_to_list_of_tensors_v2", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + Variadic:$out + ); + let attrs = (ins + ShapeArrayAttr:$out_shapes, + DTArrayAttr:$out_dtypes, + DefaultValuedAttr:$dynamic_out + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TensorBufferToTensorOp : OneFlow_BaseOp<"tensor_buffer_to_tensor", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + ShapeAttr:$instance_shape, + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TensorToTensorBufferOp : OneFlow_BaseOp<"tensor_to_tensor_buffer", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$instance_dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_TENSOR_BUFFER_OP_DEFINITIONS + +// Group: TEST +// TestDataTypeAttr, TestDynamicSource, TestListDataTypeAndListShapeAndListStringAttr, TestMultiInput, TestMultiInputGrad, TestMultiOutputOrder, TestRandomSource, TestReshape, TestSource, TestSourceMultiGpuFixedOutNum, ccrelu, ccrelu_grad, cpu_only_relu_test, test_user_op_attr_auto_type +// Total: 14 + +#ifdef GET_ONEFLOW_TEST_OP_DEFINITIONS + +def OneFlow_TestDataTypeAttrOp : 
OneFlow_BaseOp<"TestDataTypeAttr", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + OneFlow_DataType:$output_type + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestDynamicSourceOp : OneFlow_BaseOp<"TestDynamicSource", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_arg_modify_fn = 1; +} + +def OneFlow_TestListDataTypeAndListShapeAndListStringAttrOp : OneFlow_BaseOp<"TestListDataTypeAndListShapeAndListStringAttr", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + ShapeArrayAttr:$out_shapes, + DTArrayAttr:$out_types, + StrArrayAttr:$string_list + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestMultiInputOp : OneFlow_BaseOp<"TestMultiInput", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x1, + AnyType:$x2 + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestMultiInputGradOp : OneFlow_BaseOp<"TestMultiInputGrad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x1, + AnyType:$x2, + AnyType:$y_diff + ); + let output = (outs + AnyType:$x1_diff, + AnyType:$x2_diff + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestMultiOutputOrderOp : OneFlow_BaseOp<"TestMultiOutputOrder", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out1, + AnyType:$out2 + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestRandomSourceOp : OneFlow_BaseOp<"TestRandomSource", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$seed + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestReshapeOp : OneFlow_BaseOp<"TestReshape", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + ShapeAttr:$shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestSourceOp : OneFlow_BaseOp<"TestSource", [NoSideEffect, DeclareOpInterfaceMethods]> { + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestSourceMultiGpuFixedOutNumOp : OneFlow_BaseOp<"TestSourceMultiGpuFixedOutNum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let 
output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$out_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CcreluOp : OneFlow_BaseOp<"ccrelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CcreluGradOp : OneFlow_BaseOp<"ccrelu_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$y, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CpuOnlyReluTestOp : OneFlow_BaseOp<"cpu_only_relu_test", [NoSideEffect, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TestUserOpAttrAutoTypeOp : OneFlow_BaseOp<"test_user_op_attr_auto_type", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$int1, + DefaultValuedAttr:$int2 + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_TEST_OP_DEFINITIONS + +// Group: TRIGONOMETRIC +// acos, acos_grad, acosh, acosh_grad, asin, asin_grad, asinh, asinh_grad, atan, atan2, atan2_x_grad, atan2_y_grad, atan_grad, atanh, atanh_grad, cos, cos_grad, cosh, cosh_grad, hardtanh, hardtanh_grad, sin, sin_grad, sinh, sinh_grad, tan, tan_grad, tanh, tanh_grad +// Total: 29 + +#ifdef GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS + +def OneFlow_AcosOp : OneFlow_BaseOp<"acos", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AcosGradOp : OneFlow_BaseOp<"acos_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AcoshOp : OneFlow_BaseOp<"acosh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AcoshGradOp : OneFlow_BaseOp<"acosh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinOp : OneFlow_BaseOp<"asin", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = 
(ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinGradOp : OneFlow_BaseOp<"asin_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinhOp : OneFlow_BaseOp<"asinh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AsinhGradOp : OneFlow_BaseOp<"asinh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanOp : OneFlow_BaseOp<"atan", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Atan2Op : OneFlow_BaseOp<"atan2", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y + ); + let output = (outs + AnyType:$z + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Atan2XGradOp : OneFlow_BaseOp<"atan2_x_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y, + AnyType:$dz + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Atan2YGradOp : OneFlow_BaseOp<"atan2_y_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$y, + AnyType:$dz + ); + let output = (outs + AnyType:$dy + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanGradOp : OneFlow_BaseOp<"atan_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanhOp : OneFlow_BaseOp<"atanh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AtanhGradOp : OneFlow_BaseOp<"atanh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let 
has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CosOp : OneFlow_BaseOp<"cos", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CosGradOp : OneFlow_BaseOp<"cos_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CoshOp : OneFlow_BaseOp<"cosh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CoshGradOp : OneFlow_BaseOp<"cosh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardtanhOp : OneFlow_BaseOp<"hardtanh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$min_val, + DefaultValuedAttr:$max_val + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardtanhGradOp : OneFlow_BaseOp<"hardtanh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$y, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$min_val, + DefaultValuedAttr:$max_val + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinOp : OneFlow_BaseOp<"sin", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinGradOp : OneFlow_BaseOp<"sin_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinhOp : OneFlow_BaseOp<"sinh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SinhGradOp : OneFlow_BaseOp<"sinh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let 
has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanOp : OneFlow_BaseOp<"tan", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanGradOp : OneFlow_BaseOp<"tan_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanhOp : OneFlow_BaseOp<"tanh", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TanhGradOp : OneFlow_BaseOp<"tanh_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x, + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_TRIGONOMETRIC_OP_DEFINITIONS + +// Group: UNARY +// acc, affine_grid, affine_grid_grad, bernoulli, cast, cast_to_static_shape, cast_to_tick, celu, copy, count_not_finite, diag, diagonal, elu, expand, expand_dims, flatten, flip, flip_grad, fold, gelu, hardsigmoid, hardswish, leaky_relu, log2, logical_not, mish, narrow, one_hot, pack, random_mask_like, repeat, roll, selu, sigmoid, silu, softsign, sort, square_sum, squeeze, transpose, tril, triu, unfold, unfold_tensor, unpack, zero_like +// Total: 46 + +#ifdef GET_ONEFLOW_UNARY_OP_DEFINITIONS + +def OneFlow_AccOp : OneFlow_BaseOp<"acc", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$max_acc_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_AffineGridOp : OneFlow_BaseOp<"affine_grid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$theta + ); + let output = (outs + AnyType:$grid + ); + let attrs = (ins + ShapeAttr:$size, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_AffineGridGradOp : OneFlow_BaseOp<"affine_grid_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dgrid + ); + let output = (outs + AnyType:$dtheta + ); + let attrs = (ins + ShapeAttr:$size, + DefaultValuedAttr:$align_corners + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_BernoulliOp : OneFlow_BaseOp<"bernoulli", [NoSideEffect, NoGrad, CpuOnly, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$seed, + DefaultValuedAttr:$has_seed, + 
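The seed/has_seed pair here is the usual convention for RNG ops such as bernoulli: has_seed records whether the caller pinned a seed, and the kernel falls back to a fresh one otherwise. At run time the attributes come back out of the context through the templated Attr accessor — a sketch, with the fallback policy being an assumption:

    // Inside a hypothetical bernoulli kernel (std::random_device from <random>):
    const bool has_seed = ctx->Attr<bool>("has_seed");
    const int64_t seed = has_seed ? ctx->Attr<int64_t>("seed")
                                  : static_cast<int64_t>(std::random_device{}());
    const DataType dtype = ctx->Attr<DataType>("dtype");  // requested output dtype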
OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastOp : OneFlow_BaseOp<"cast", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastToStaticShapeOp : OneFlow_BaseOp<"cast_to_static_shape", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input + ); + let output = (outs + AnyType:$output + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CastToTickOp : OneFlow_BaseOp<"cast_to_tick", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_nd_sbp_infer_fn = 1; +} + +def OneFlow_CeluOp : OneFlow_BaseOp<"celu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_CopyOp : OneFlow_BaseOp<"copy", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$device_type, + DefaultValuedAttr:$device_id + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_device_infer_fn = 1; +} + +def OneFlow_CountNotFiniteOp : OneFlow_BaseOp<"count_not_finite", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagOp : OneFlow_BaseOp<"diag", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_DiagonalOp : OneFlow_BaseOp<"diagonal", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$offset + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_EluOp : OneFlow_BaseOp<"elu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let 
has_data_type_infer_fn = 1; +} + +def OneFlow_ExpandOp : OneFlow_BaseOp<"expand", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + SI32ArrayAttr:$logical_in_shape, + SI32ArrayAttr:$logical_expand_shape + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_ExpandDimsOp : OneFlow_BaseOp<"expand_dims", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$axis + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlattenOp : OneFlow_BaseOp<"flatten", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$start_dim, + DefaultValuedAttr:$end_dim + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlipOp : OneFlow_BaseOp<"flip", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + SI32ArrayAttr:$dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FlipGradOp : OneFlow_BaseOp<"flip_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + SI32ArrayAttr:$dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_FoldOp : OneFlow_BaseOp<"fold", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + SI32ArrayAttr:$output_size, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$padding, + SI32ArrayAttr:$dilation_rate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_GeluOp : OneFlow_BaseOp<"gelu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_HardswishOp : OneFlow_BaseOp<"hardswish", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LeakyReluOp : OneFlow_BaseOp<"leaky_relu", [NoSideEffect, 
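Throughout this file the trait list pairs NoSideEffect (and, where relevant, NoGrad or CpuOnly) with DeclareOpInterfaceMethods, whose interface parameter — like all other angle-bracketed arguments — was stripped in this dump. The has_*_fn = 1 flags advertise which C++ hooks the op supplies; the oneflow-tblgen backend added by this PR presumably expands each flag into a declaration of the corresponding hook, roughly as below (the struct and method names are illustrative, not the generator's actual output):

    // Hypothetical shape of what the op-schema generator emits for an op
    // that sets all four common flags:
    struct LeakyReluOpSchema {
      static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);   // has_logical_tensor_desc_infer_fn
      static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);  // has_physical_tensor_desc_infer_fn
      static Maybe<void> GetSbp(user_op::SbpContext* ctx);                     // has_get_sbp_fn
      static Maybe<void> InferDataType(user_op::InferContext* ctx);            // has_data_type_infer_fn
    };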
DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$alpha + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_Log2Op : OneFlow_BaseOp<"log2", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_LogicalNotOp : OneFlow_BaseOp<"logical_not", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_MishOp : OneFlow_BaseOp<"mish", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_NarrowOp : OneFlow_BaseOp<"narrow", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$dim, + DefaultValuedAttr:$start, + DefaultValuedAttr:$length + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_OneHotOp : OneFlow_BaseOp<"one_hot", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$indices + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$depth, + DefaultValuedAttr:$floating_on_value, + DefaultValuedAttr:$integer_on_value, + DefaultValuedAttr:$floating_off_value, + DefaultValuedAttr:$integer_off_value, + OneFlow_DataType:$dtype + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_input_arg_modify_fn = 1; +} + +def OneFlow_PackOp : OneFlow_BaseOp<"pack", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$pack_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_RandomMaskLikeOp : OneFlow_BaseOp<"random_mask_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$like + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$rate, + DefaultValuedAttr:$seed + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_RepeatOp : OneFlow_BaseOp<"repeat", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$repeat_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let 
has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_RollOp : OneFlow_BaseOp<"roll", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + SI32ArrayAttr:$shifts, + SI32ArrayAttr:$dims + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SeluOp : OneFlow_BaseOp<"selu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SigmoidOp : OneFlow_BaseOp<"sigmoid", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SiluOp : OneFlow_BaseOp<"silu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SoftsignOp : OneFlow_BaseOp<"softsign", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SortOp : OneFlow_BaseOp<"sort", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + StrAttr:$direction + ); + let has_check_fn = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SquareSumOp : OneFlow_BaseOp<"square_sum", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_SqueezeOp : OneFlow_BaseOp<"squeeze", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + SI32ArrayAttr:$axes + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TransposeOp : OneFlow_BaseOp<"transpose", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$input + ); + let output = (outs + AnyType:$output + ); + let attrs = (ins + SI32ArrayAttr:$perm + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TrilOp : OneFlow_BaseOp<"tril", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal, + 
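tril keeps one fill value per numeric family plus a selector flag, mirroring the int/float operand pattern of the scalar ops. Reading it back in a kernel is a two-way branch — a sketch, where the C++ attribute types (double/int64_t) are an assumption since the typed attribute parameters are stripped in this dump:

    // Selecting tril's fill value for an output element type T (sketch):
    template<typename T>
    T GetTrilFillValue(user_op::KernelComputeContext* ctx) {
      return ctx->Attr<bool>("is_floating_fill_value")
                 ? static_cast<T>(ctx->Attr<double>("floating_fill_value"))
                 : static_cast<T>(ctx->Attr<int64_t>("integer_fill_value"));
    }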
DefaultValuedAttr:$floating_fill_value, + DefaultValuedAttr:$integer_fill_value, + DefaultValuedAttr:$is_floating_fill_value + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_TriuOp : OneFlow_BaseOp<"triu", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$diagonal + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnfoldOp : OneFlow_BaseOp<"unfold", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + StrAttr:$data_format, + SI32ArrayAttr:$kernel_size, + SI32ArrayAttr:$padding, + SI32ArrayAttr:$strides, + SI32ArrayAttr:$dilation_rate + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnfoldTensorOp : OneFlow_BaseOp<"unfold_tensor", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$dimension, + DefaultValuedAttr:$size, + DefaultValuedAttr:$step + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UnpackOp : OneFlow_BaseOp<"unpack", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$in + ); + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr:$unpack_num + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; + let has_output_blob_time_shape_infer_fn = 1; +} + +def OneFlow_ZeroLikeOp : OneFlow_BaseOp<"zero_like", [NoSideEffect, NoGrad, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$like + ); + let output = (outs + AnyType:$out + ); + let same_output_regst_num = 1; + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_UNARY_OP_DEFINITIONS + +// Group: UPSAMPLE +// upsample, upsample_bicubic_2d, upsample_bicubic_2d_grad, upsample_bilinear_2d, upsample_bilinear_2d_grad, upsample_grad, upsample_linear_1d, upsample_linear_1d_grad, upsample_nearest_1d, upsample_nearest_1d_grad, upsample_nearest_2d, upsample_nearest_2d_grad, upsample_nearest_3d, upsample_nearest_3d_grad, upsample_trilinear_3d, upsample_trilinear_3d_grad +// Total: 16 + +#ifdef GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS + +def OneFlow_UpsampleOp : OneFlow_BaseOp<"upsample", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format, + StrAttr:$interpolation + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBicubic2DOp : OneFlow_BaseOp<"upsample_bicubic_2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let 
output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBicubic2DGradOp : OneFlow_BaseOp<"upsample_bicubic_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBilinear2DOp : OneFlow_BaseOp<"upsample_bilinear_2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleBilinear2DGradOp : OneFlow_BaseOp<"upsample_bilinear_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleGradOp : OneFlow_BaseOp<"upsample_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format, + StrAttr:$interpolation + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleLinear1DOp : OneFlow_BaseOp<"upsample_linear_1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleLinear1DGradOp : OneFlow_BaseOp<"upsample_linear_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest1DOp : OneFlow_BaseOp<"upsample_nearest_1d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + 
AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest1DGradOp : OneFlow_BaseOp<"upsample_nearest_1d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$scale_factor, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest2DOp : OneFlow_BaseOp<"upsample_nearest_2d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest2DGradOp : OneFlow_BaseOp<"upsample_nearest_2d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest3DOp : OneFlow_BaseOp<"upsample_nearest_3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleNearest3DGradOp : OneFlow_BaseOp<"upsample_nearest_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleTrilinear3DOp : OneFlow_BaseOp<"upsample_trilinear_3d", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$x + ); + let output = (outs + AnyType:$y + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +def OneFlow_UpsampleTrilinear3DGradOp : OneFlow_BaseOp<"upsample_trilinear_3d_grad", [NoSideEffect, DeclareOpInterfaceMethods]> { + let input = (ins + AnyType:$dy, + AnyType:$x + ); + let output = (outs + AnyType:$dx + ); + let attrs = (ins + DefaultValuedAttr:$depth_scale, + DefaultValuedAttr:$height_scale, + 
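Every op in this upsample group is parameterized by per-axis scale factors (the interpolating variants add align_corners), so tensor-desc inference reduces to scaling the input's spatial extents. Roughly — a sketch that assumes float-typed scale attributes and elides data_format handling:

    // Inferring the output spatial extents of a 3-D upsample (sketch);
    // in_d/in_h/in_w are the input extents taken from ctx->InputShape("x", 0):
    const int64_t out_d = static_cast<int64_t>(in_d * ctx->Attr<float>("depth_scale"));
    const int64_t out_h = static_cast<int64_t>(in_h * ctx->Attr<float>("height_scale"));
    const int64_t out_w = static_cast<int64_t>(in_w * ctx->Attr<float>("width_scale"));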
DefaultValuedAttr:$width_scale, + DefaultValuedAttr:$align_corners, + StrAttr:$data_format + ); + let has_logical_tensor_desc_infer_fn = 1; + let has_physical_tensor_desc_infer_fn = 1; + let has_get_sbp_fn = 1; + let has_data_type_infer_fn = 1; +} + +#endif // GET_ONEFLOW_UPSAMPLE_OP_DEFINITIONS diff --git a/oneflow/ir/install-llvm.cmake b/oneflow/ir/install-llvm.cmake index 1747ba5d77f..dc16b468c74 100644 --- a/oneflow/ir/install-llvm.cmake +++ b/oneflow/ir/install-llvm.cmake @@ -38,6 +38,13 @@ if(NOT llvm_monorepo_POPULATED) -DLLVM_ENABLE_BINDINGS=OFF -DMLIR_ENABLE_CUDA_RUNNER=${WITH_MLIR_CUDA_CODEGEN} -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} + -DINJA_URL=${INJA_URL} + -DINJA_MD5=${INJA_MD5} + -DJSON_URL=${JSON_URL} + -DJSON_MD5=${JSON_MD5} + -DCMAKE_CUDA_COMPILER=${CMAKE_CUDA_COMPILER} + -DLLVM_EXTERNAL_PROJECTS=OneFlowTableGen + -DLLVM_EXTERNAL_ONEFLOWTABLEGEN_SOURCE_DIR=${CMAKE_SOURCE_DIR}/tools/oneflow-tblgen -G ${CMAKE_GENERATOR} WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} RESULT_VARIABLE ret) @@ -46,24 +53,21 @@ if(NOT llvm_monorepo_POPULATED) endif() include(ProcessorCount) ProcessorCount(PROC_NUM) - execute_process(COMMAND "${CMAKE_COMMAND}" --build . -j${PROC_NUM} - WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} - RESULT_VARIABLE ret - ) - if(ret EQUAL "1") - message( FATAL_ERROR "Bad exit status") + if(WITH_MLIR) + set(INSTALL_ALL "install") endif() - execute_process(COMMAND "${CMAKE_COMMAND}" --build . -j${PROC_NUM} --target install + execute_process(COMMAND "${CMAKE_COMMAND}" --build . -j${PROC_NUM} --target ${INSTALL_ALL} install-oneflow-tblgen install-mlir-headers WORKING_DIRECTORY ${llvm_monorepo_BINARY_DIR} RESULT_VARIABLE ret ) if(ret EQUAL "1") message( FATAL_ERROR "Bad exit status") endif() - set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) - set(MLIR_DIR ${LLVM_INSTALL_DIR}/lib/cmake/mlir) endif() +if (WITH_MLIR) +set(LLVM_DIR ${LLVM_INSTALL_DIR}/lib/cmake/llvm) +set(MLIR_DIR ${LLVM_INSTALL_DIR}/lib/cmake/mlir) find_package(MLIR REQUIRED CONFIG) message(STATUS "Using MLIRConfig.cmake in: ${MLIR_DIR}") @@ -78,3 +82,4 @@ include(AddLLVM) include(AddMLIR) include(HandleLLVMOptions) set(LLVM_EXTERNAL_LIT "${llvm_monorepo_BINARY_DIR}/bin/llvm-lit" CACHE STRING "" FORCE) +endif() diff --git a/oneflow/ir/lib/OneFlow/CMakeLists.txt b/oneflow/ir/lib/OneFlow/CMakeLists.txt index e090b742a3c..b9f5a0e0a59 100644 --- a/oneflow/ir/lib/OneFlow/CMakeLists.txt +++ b/oneflow/ir/lib/OneFlow/CMakeLists.txt @@ -33,7 +33,6 @@ oneflow_add_mlir_dialect_library(MLIROneFlow DEPENDS MLIROneFlowOpsIncGen prepare_oneflow_third_party - oneflow-gen-ods LINK_LIBS PUBLIC ${dialect_libs} diff --git a/oneflow/ir/llvm-in-tree.cmake b/oneflow/ir/llvm-in-tree.cmake index c1c74645111..5e8d2afc3f0 100644 --- a/oneflow/ir/llvm-in-tree.cmake +++ b/oneflow/ir/llvm-in-tree.cmake @@ -19,7 +19,7 @@ set(CMAKE_CXX_FLAGS_DEBUG "" CACHE STRING "" FORCE) set(CMAKE_CXX_FLAGS_RELEASE "" CACHE STRING "" FORCE) set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "" CACHE STRING "" FORCE) -set(CMAKE_INSTALL_PREFIX ${LLVM_INSTALL_DIR} CACHE STRING "") +set(CMAKE_INSTALL_PREFIX ${LLVM_INSTALL_DIR} CACHE STRING "" FORCE) set(LLVM_ENABLE_RTTI ON CACHE BOOL "turn this on to make it compatible with protobuf") set(LLVM_ENABLE_EH ON CACHE BOOL "turn this on to make it compatible with half (the library)") set(LLVM_BUILD_EXAMPLES OFF CACHE BOOL "") @@ -46,6 +46,7 @@ set(MLIR_INCLUDE_DIR ${LLVM_MAIN_SRC_DIR}/../mlir/include) set(MLIR_GENERATED_INCLUDE_DIR ${LLVM_BINARY_DIR}/tools/mlir/include) set(MLIR_INCLUDE_DIRS 
"${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") + set(llvm_monorepo_BINARY_DIR ${llvm_monorepo_BINARY_DIR}) install(TARGETS oneflow of_protoobj of_cfgobj of_functional_obj EXPORT oneflow DESTINATION lib) install(EXPORT oneflow DESTINATION lib/oneflow) diff --git a/oneflow/ir/oneflow-extension/extension.cpp b/oneflow/ir/oneflow-extension/extension.cpp index 466f1567fcf..56ae0ac2afd 100644 --- a/oneflow/ir/oneflow-extension/extension.cpp +++ b/oneflow/ir/oneflow-extension/extension.cpp @@ -40,8 +40,8 @@ namespace { REGISTER_USER_OP("mlir_jit") .Attr("mlir_assembly") - .InputWithMinimum("in", 0) - .OutputWithMinimum("out", 0) + .Input("in") + .Output("out") .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { // TODO: infer shape by extracting Ops from mlir_assembly CHECK_EQ(ctx->inputs().size(), 2); diff --git a/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp b/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp deleted file mode 100644 index a70d9745e8b..00000000000 --- a/oneflow/ir/oneflow-gen-ods/oneflow-gen-ods.cpp +++ /dev/null @@ -1,739 +0,0 @@ -/* -Copyright 2020 The OneFlow Authors. All rights reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -#include -#include "oneflow/core/framework/user_op_def.h" -#include "oneflow/core/framework/user_op_registry.h" -#include "oneflow/core/framework/user_op_registry_manager.h" -#include - -namespace { - -using K = std::string; -using V = ::oneflow::user_op::OpRegistryResult; -using ::oneflow::AttrType; -using ::oneflow::UserOpDef_ArgDef; - -// from llvm -std::string convertToCamelFromSnakeCase(const std::string& input, bool capitalizeFirst) { - if (input.empty()) return ""; - - std::string output; - output.reserve(input.size()); - - // Push the first character, capatilizing if necessary. - if (capitalizeFirst && std::islower(input.front())) - output.push_back(toupper(input.front())); - else - output.push_back(input.front()); - - // Walk the input converting any `*_[a-z]` snake case into `*[A-Z]` camelCase. 
- for (size_t pos = 1, e = input.size(); pos < e; ++pos) { - if (input[pos] == '_' && pos != (e - 1) && std::islower(input[pos + 1])) - output.push_back(toupper(input[++pos])); - else - output.push_back(input[pos]); - } - return output; -} - -std::string GetMLIRAttrTypeName(const AttrType& attr_type) { - if (attr_type == ::oneflow::kAtInt32) { - return "SI32Attr"; - } else if (attr_type == ::oneflow::kAtInt64) { - return "SI64Attr"; - } else if (attr_type == ::oneflow::kAtBool) { - return "BoolAttr"; - } else if (attr_type == ::oneflow::kAtFloat) { - return "F32Attr"; - } else if (attr_type == ::oneflow::kAtDouble) { - return "F64Attr"; - } else if (attr_type == ::oneflow::kAtString) { - return "StrAttr"; - } else if (attr_type == ::oneflow::kAtShape) { - return "ShapeAttr"; - } else if (attr_type == ::oneflow::kAtDataType) { - return "OneFlow_DataType"; - } else if (attr_type == ::oneflow::kAtListInt32) { - return "SI32ArrayAttr"; - } else if (attr_type == ::oneflow::kAtListInt64) { - return "SI64ArrayAttr"; - } else if (attr_type == ::oneflow::kAtListFloat) { - return "F32ArrayAttr"; - } else if (attr_type == ::oneflow::kAtListDataType) { - return "DTArrayAttr"; - } else if (attr_type == ::oneflow::kAtListShape) { - return "ShapeArrayAttr"; - } else if (attr_type == ::oneflow::kAtListString) { - return "StrArrayAttr"; - } else { - LOG(FATAL) << "fail to convert: " << attr_type; - return "failure"; - } -} - -template -std::string ToZeroNoTrailing(T f) { - std::string str = std::to_string(f); - str.erase(str.find_last_not_of('0') + 1, std::string::npos); - return str; -} - -std::string GetDefaultValue(const ::oneflow::AttrValue& attr_val) { - if (attr_val.has_at_string()) { - return "\\\"" + attr_val.at_string() + "\\\""; - } else if (attr_val.has_at_int32()) { - return std::to_string(attr_val.at_int32()); - } else if (attr_val.has_at_int64()) { - return std::to_string(attr_val.at_int64()); - } else if (attr_val.has_at_float()) { - return ToZeroNoTrailing(attr_val.at_float()); - } else if (attr_val.has_at_double()) { - return ToZeroNoTrailing(attr_val.at_double()); - } else if (attr_val.has_at_bool()) { - return attr_val.at_bool() ? "true" : "false"; - } else if (attr_val.has_at_list_int32()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_int32().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += std::to_string(*it) + (std::next(it) == list.end() ? "" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_list_int64()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_int64().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += std::to_string(*it) + (std::next(it) == list.end() ? "" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_list_float()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_float().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += std::to_string(*it) + (std::next(it) == list.end() ? "" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_list_string()) { - std::string ret = "{"; - const auto& list = attr_val.at_list_string().val(); - for (auto it = list.begin(); it != list.end(); ++it) { - ret += "\"" + *it + "\"" + (std::next(it) == list.end() ? 
"" : ", "); - } - ret += "}"; - return ret; - } else if (attr_val.has_at_data_type()) { - return std::to_string(attr_val.at_data_type()); - } - LOG(FATAL) << "fail to convert value_case: " << attr_val.value_case() << "\n" - << attr_val.DebugString(); -} - -std::string GetMLIRAttrType(const ::oneflow::UserOpDef_AttrDef& attr_def) { - const AttrType& attr_type = attr_def.type(); - std::string name = GetMLIRAttrTypeName(attr_type); - auto is_default_supported = - attr_def.default_val().has_at_bool() || attr_def.default_val().has_at_int32() - || attr_def.default_val().has_at_int64() || attr_def.default_val().has_at_float() - || attr_def.default_val().has_at_double() - || (attr_def.default_val().has_at_string() && attr_def.default_val().at_string().size() > 0); - if (attr_def.has_default_val() && is_default_supported) { - name = - "DefaultValuedAttr<" + name + ", " + "\"" + GetDefaultValue(attr_def.default_val()) + "\">"; - } - return name; -} - -const std::set& GetIdempotentOps() { - static std::set ret{"abs", "ceil", "floor", "ones_like", "relu", "relu_grad", - "relu6", "rint", "round", "sign", "zeros_like"}; - return ret; -} -const std::set& GetInvolutionOps() { - static std::set ret{"reciprocal", "negative"}; - return ret; -} - -bool IsGradOp(const std::string& op_name) { return op_name.find("grad") != std::string::npos; } -const std::set& GetQuantizationOps() { - static std::set ret{"min_max_observer", "moving_average_min_max_observer", - "fake_quantization", "quantization"}; - return ret; -} - -const std::set& GetMathOps() { - static std::set ret{"abs", "acos", - "acosh", "asin", - "asinh", "atan", - "atanh", "ceil", - "cos", "cosh", - "erf", "erfc", - "exp", "expm1", - "floor", "lgamma", - "log", "log1p", - "log_sigmoid", "negative", - "reciprocal", "reciprocal_no_nan", - "rint", "round", - "rsqrt", "sigmoid_v2", - "sign", "sin", - "sinh", "softplus", - "sqrt", "square", - "tan", "tanh"}; - return ret; -} - -const std::set& GetOpsUsedInPatterns() { - static std::set ret{"scalar_mul_by_tensor", "cast", "tril", "scalar_mul", - "fused_scale_tril", "dropout", "bias_add"}; - return ret; -} -bool IsMathOp(const std::string& op_name) { - bool is_grad = false; - for (const auto& name : GetMathOps()) { - if (op_name.find(name) != std::string::npos && IsGradOp(op_name)) { is_grad = true; } - } - return GetMathOps().find(op_name) != GetMathOps().end() || is_grad; -} -bool IsUsedInPatterns(const std::string& op_name) { - return GetOpsUsedInPatterns().find(op_name) != GetOpsUsedInPatterns().end(); -} -bool IsInvolutionOp(const std::string& op_name) { - return GetInvolutionOps().find(op_name) != GetInvolutionOps().end() && !IsGradOp(op_name); -} -bool IsQuantizationOp(const std::string& op_name) { - return GetQuantizationOps().find(op_name) != GetQuantizationOps().end(); -} -bool IsIdempotentOp(const std::string& op_name) { - return GetIdempotentOps().find(op_name) != GetIdempotentOps().end() && !IsGradOp(op_name); -} - -bool IsPoolOp(const std::string& op_name) { - return ((op_name.rfind("avg", 0) == 0 || op_name.rfind("max", 0) == 0) - || ((op_name.find("avg") != std::string::npos || op_name.find("max") != std::string::npos) - && op_name.rfind("tf", 0) == 0)) - && op_name.find("pool") != std::string::npos; -} -bool IsEagerOp(const std::string& op_name) { return (op_name.rfind("eager", 0) == 0); } -bool IsTensorBufferOp(const std::string& op_name) { - return op_name.find("tensor_buffer") != std::string::npos; -} -bool IsSummaryOp(const std::string& op_name) { - return op_name.find("summary") != 
std::string::npos; -} -bool IsAnyPoolOp(const std::string& op_name) { return op_name.find("pool") != std::string::npos; } -bool IsAnyConvOp(const std::string& op_name) { return op_name.find("conv") != std::string::npos; } -bool IsConvOp(const std::string& op_name) { - return op_name.rfind("conv", 0) == 0 && op_name.find("grad") == std::string::npos; -} - -bool IsLazyPoolOp(const std::string& op_name) { - return op_name.find("_pool") != std::string::npos && op_name.find("tf_") != std::string::npos; -} - -bool IsMaxPoolOp(const std::string& op_name) { - return op_name.find("maxpool") != std::string::npos; -} - -bool IsAvgPoolOp(const std::string& op_name) { - return op_name.find("avgpool") != std::string::npos; -} - -bool IsAdaptivePoolOp(const std::string& op_name) { - return op_name.find("_pool") != std::string::npos - && op_name.find("adaptive_") != std::string::npos; -} -bool IsNCCLOp(const std::string& op_name) { return op_name.find("nccl") != std::string::npos; } -bool IsOptimizerOp(const std::string& op_name) { - return (op_name.find("update") != std::string::npos || op_name.find("adam") != std::string::npos) - && op_name.find("scatter") == std::string::npos; -} -bool IsTrigonometric(const std::string& op_name) { - return (op_name.find("sin") != std::string::npos || op_name.find("cos") != std::string::npos - || op_name.find("tan") != std::string::npos) - && op_name.find("constant") == std::string::npos; -} -bool IsTestOp(const std::string& op_name) { - return (op_name.find("test") != std::string::npos || op_name.find("Test") != std::string::npos - || op_name.find("ccrelu") != std::string::npos); -} -bool IsPaddingOp(const std::string& op_name) { return (op_name.find("pad") != std::string::npos); } -bool IsAssignOp(const std::string& op_name) { - return (op_name.find("assign") != std::string::npos); -} -bool IsCrossEntropyOp(const std::string& op_name) { - return (op_name.find("cross_entropy") != std::string::npos); -} -bool IsCUDAOp(const std::string& op_name) { return (op_name.find("nvtx") != std::string::npos); } -bool IsMatmulOp(const std::string& op_name) { - return (op_name.find("matmul") != std::string::npos || op_name.find("fc") != std::string::npos); -} - -bool IsDatasetOp(const std::string& op_name) { - return (op_name.find("reader") != std::string::npos || op_name.find("Reader") != std::string::npos - || op_name.find("loader") != std::string::npos - || op_name.find("decoder") != std::string::npos); -} -bool IsUpsampleOp(const std::string& op_name) { - return (op_name.find("upsample") != std::string::npos); -} -bool IsBroadcastOp(const std::string& op_name) { - return (op_name.find("broadcast") != std::string::npos); -} -bool IsIdentityOp(const std::string& op_name) { - return (op_name.find("identity") != std::string::npos); -} -bool IsScalarOp(const std::string& op_name) { - return (op_name.rfind("scalar_", 0) == 0 || op_name.find("by_scalar") != std::string::npos); -} -bool IsImageOp(const std::string& op_name) { return (op_name.find("image") != std::string::npos); } -bool IsSoftmaxOp(const std::string& op_name) { - return (op_name.find("softmax") != std::string::npos); -} -bool IsFusedOp(const std::string& op_name) { - return (op_name.find("fused") != std::string::npos - || op_name.find("add_relu") != std::string::npos); -} -bool IsReduceOp(const std::string& op_name) { - return (op_name.find("reduce") != std::string::npos); -} -bool IsReshapeOp(const std::string& op_name) { - return (op_name.find("reshape") != std::string::npos); -} -bool IsLossOp(const std::string& 
op_name) { return (op_name.find("loss") != std::string::npos); } -bool IsDetectionOp(const std::string& op_name) { - return (op_name.find("top_k") != std::string::npos || op_name.find("bbox") != std::string::npos - || op_name.find("segmentation") != std::string::npos - || op_name.find("roi") != std::string::npos || op_name.find("poly") != std::string::npos - || op_name.find("nms") != std::string::npos - || op_name.find("object") != std::string::npos); -} -bool IsIndicesOp(const std::string& op_name) { - return (op_name.find("arg") != std::string::npos || op_name.find("where") != std::string::npos - || op_name.find("gather") != std::string::npos - || op_name.find("slice") != std::string::npos - || op_name.find("indices") != std::string::npos - || op_name.find("segment_sum") != std::string::npos - || op_name.find("scatter") != std::string::npos); -} -bool IsNormalizationOp(const std::string& op_name) { - return (op_name.find("norm") != std::string::npos); -} -bool IsParallelCastOp(const std::string& op_name) { - return (op_name.find("parallel_cast") != std::string::npos); -} - -std::string PostProcessClassName(const std::string& op_name) { - std::string ret = op_name; - ret = std::regex_replace(ret, std::regex("pool"), "Pool"); - ret = std::regex_replace(ret, std::regex("_1d"), "1D"); - ret = std::regex_replace(ret, std::regex("_2d"), "2D"); - ret = std::regex_replace(ret, std::regex("_3d"), "3D"); - ret = std::regex_replace(ret, std::regex("1d"), "1D"); - ret = std::regex_replace(ret, std::regex("2d"), "2D"); - ret = std::regex_replace(ret, std::regex("3d"), "3D"); - return ret; -} - -std::string GetConvOpClassName(const std::string& op_name) { - std::string ret(convertToCamelFromSnakeCase(op_name, true)); - // NOTE: should change form conv => Convolution ? - return ret; -} - -std::string GetBaseOp(const std::string& op_name) { - if (IsInvolutionOp(op_name)) { - return "OneFlow_InvolutionBaseOp"; - } else if (IsIdempotentOp(op_name)) { - return "OneFlow_IdempotentBaseOp"; - } else if (IsConvOp(op_name)) { - return "OneFlow_ConvolutionBaseOp"; - } else if (IsPoolOp(op_name)) { - if (IsLazyPoolOp(op_name)) { - return "OneFlow_" + std::string("TFPool") + std::string(IsGradOp(op_name) ? "Grad" : "") - + "BaseOp"; - } else { - return "OneFlow_" + std::string(IsMaxPoolOp(op_name) ? "Max" : "") - + std::string(IsAvgPoolOp(op_name) ? "Avg" : "") + "Pool" - + std::string(IsGradOp(op_name) ? "Grad" : "") + "BaseOp"; - } - } else if (IsAdaptivePoolOp(op_name)) { - return "OneFlow_AdaptivePool" + std::string(IsGradOp(op_name) ? 
"Grad" : "") + "BaseOp"; - } else { - return "OneFlow_BaseOp"; - } -} - -bool ShouldSkipOperandAndResultsAndAttrs(const std::string& op_name) { - return IsInvolutionOp(op_name) || IsIdempotentOp(op_name); -} - -bool ShouldGenEmptyBody(const std::string& op_name) { - return IsPoolOp(op_name) || IsAdaptivePoolOp(op_name) || IsConvOp(op_name); -} - -void PrintArgDef(const UserOpDef_ArgDef& arg_def) { - std::cout << " "; - if (arg_def.is_optional()) { std::cout << "Optional<"; } - if (arg_def.num_as_min()) { std::cout << "Variadic<"; } - std::cout << "AnyType"; - if (arg_def.is_optional() || arg_def.num_as_min()) { std::cout << ">"; } - CHECK(!(arg_def.is_optional() && arg_def.num_as_min())) << arg_def.DebugString(); - std::cout << ":$" << arg_def.name(); - if (arg_def.num_as_min()) { - // TODO: add verifier - } -} - -uint32_t NumMultipleVariadic( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - uint32_t num_variadic_op = 0; - for (const auto& arg_def : arg_defs) { - if (arg_def.is_optional()) { num_variadic_op += 1; } - if (arg_def.num_as_min()) { num_variadic_op += 1; } - } - return num_variadic_op; -} - -bool HasAtLeastTwoVariadic( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - return NumMultipleVariadic(arg_defs) > 1; -} - -bool HasVariadic( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - return NumMultipleVariadic(arg_defs) > 0; -} - -std::string GetOperandKeys( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - std::string ret = "{"; - for (auto it = arg_defs.begin(); it != arg_defs.end(); ++it) { - ret += ("\"" + it->name() + "\""); - if (std::next(it) != arg_defs.end()) { ret += ", "; } - } - ret += "}"; - return ret; -} - -std::string GetOperandMinimums( - const ::google::protobuf::RepeatedPtrField<::oneflow::UserOpDef_ArgDef>& arg_defs) { - std::string ret = "{"; - for (auto it = arg_defs.begin(); it != arg_defs.end(); ++it) { - uint32_t min = 0; - if (it->is_optional()) { - min = 0; - } else if (it->has_num_as_min()) { - min = it->num(); - } else { - min = 1; - } - ret += std::to_string(min); - if (std::next(it) != arg_defs.end()) { ret += ", "; } - } - ret += "}"; - return ret; -} - -// TODO: use MLIR Interfaces it implement this -void PrintReturnStaticVal(const std::string& type, const std::string& func_name, - const std::string& val) { - std::cout << " static const " + type + "* " + func_name + "() { static " + type + " val(" + val - + "); return &val; }\n"; -} -void PrintExtraClassDeclaration(const ::oneflow::UserOpDef& op_def) { - return; - std::cout << " let extraClassDeclaration = [{" - << "\n"; - PrintReturnStaticVal("std::vector", "inputKeys", GetOperandKeys(op_def.input())); - PrintReturnStaticVal("std::vector", "inputMinimums", - GetOperandMinimums(op_def.input())); - PrintReturnStaticVal("std::vector", "outputKeys", GetOperandKeys(op_def.output())); - PrintReturnStaticVal("std::vector", "outputMinimums", - GetOperandMinimums(op_def.input())); - std::cout << " }];" - << "\n"; -} - -void PrintHasCanonicalizer(const std::string& op_name) { - if (op_name == "add_n") { - std::cout << " let hasCanonicalizer = 1;" - << "\n"; - } -} - -void PrintTraitAttrs(const ::oneflow::UserOpDef& op_def) { - const bool need_operand_segment_sizes = HasAtLeastTwoVariadic(op_def.input()); - const bool need_result_segment_sizes = HasAtLeastTwoVariadic(op_def.output()); - if (need_operand_segment_sizes || need_result_segment_sizes) { - 
std::cout << " let trait_attrs = (ins" - << "\n"; - if (need_operand_segment_sizes) { - std::cout << " I32ElementsAttr:$operand_segment_sizes" - << (need_result_segment_sizes ? ",\n" : "\n"); - } - if (need_result_segment_sizes) { std::cout << " I32ElementsAttr:$result_segment_sizes\n"; } - std::cout << " );" - << "\n"; - } -} - -bool IsUnaryOp(const ::oneflow::user_op::OpRegistryResult& r) { - return NumMultipleVariadic(r.op_def.input()) == 0 && NumMultipleVariadic(r.op_def.output()) == 0 - && r.op_def.input().size() == 1 && r.op_def.output().size() == 1; -} - -bool IsBinaryOp(const ::oneflow::user_op::OpRegistryResult& r) { - return NumMultipleVariadic(r.op_def.input()) == 0 && NumMultipleVariadic(r.op_def.output()) == 0 - && r.op_def.input().size() == 2 && r.op_def.output().size() == 1; -} - -void PrintBody(const ::oneflow::user_op::OpRegistryResult& r) { - const ::oneflow::UserOpDef& op_def = r.op_def; - // TODO: handle in out size/optional - // TODO: handle "," in last element - std::cout << "{" - << "\n"; - // inputs - const bool should_skip_operand_and_results_and_attrs = - ShouldSkipOperandAndResultsAndAttrs(r.op_type_name); - const bool should_skip_operand = should_skip_operand_and_results_and_attrs; - const bool should_skip_result = should_skip_operand_and_results_and_attrs; - const bool should_skip_attrs = should_skip_operand_and_results_and_attrs; - if (op_def.input().size() && !should_skip_operand) { - std::cout << " let input = (ins" - << "\n"; - for (auto it = op_def.input().begin(); it != op_def.input().end(); ++it) { - PrintArgDef(*it); - std::cout << (std::next(it) == op_def.input().end() ? "" : ",") << "\n"; - } - std::cout << " );" - << "\n"; - } - // outputs - if (op_def.output().size() && !should_skip_result) { - std::cout << " let output = (outs" - << "\n"; - for (auto it = op_def.output().begin(); it != op_def.output().end(); ++it) { - PrintArgDef(*it); - std::cout << (std::next(it) == op_def.output().end() ? "" : ",") << "\n"; - } - std::cout << " );" - << "\n"; - } - // attrs - if (op_def.attr().size() && !should_skip_attrs) { - std::cout << " let attrs = (ins" - << "\n"; - for (auto it = op_def.attr().begin(); it != op_def.attr().end(); ++it) { - std::cout << " " << GetMLIRAttrType(*it) << ":$" << it->name() - << (std::next(it) == op_def.attr().end() ? 
"" : ",") << "\n"; - } - std::cout << " );" - << "\n"; - } - // trait attrs - PrintTraitAttrs(op_def); - PrintExtraClassDeclaration(op_def); - PrintHasCanonicalizer(r.op_type_name); - std::cout << "}" - << "\n"; -} - -bool ShouldGenBaseClass(const std::string& op_name) { return op_name == "normalization_add_relu"; } - -bool HasSideEffect(const std::string& op_name) { - return IsAssignOp(op_name) || IsOptimizerOp(op_name); -} - -std::string GetOpClassName(const std::string& op_name) { - std::string ret = ""; - if (IsConvOp(op_name)) { - ret = GetConvOpClassName(op_name); - } else { - ret = convertToCamelFromSnakeCase(op_name, true); - } - if (ShouldGenBaseClass(op_name)) { ret += "Base"; } - return PostProcessClassName(ret); -} - -std::string GetTraits(const ::oneflow::user_op::OpRegistryResult& r) { - const ::oneflow::UserOpDef& op_def = r.op_def; - std::string ret{}; - if (HasSideEffect(r.op_type_name) == false) { ret += "NoSideEffect"; } - const bool need_operand_segment_sizes = HasAtLeastTwoVariadic(op_def.input()); - const bool need_result_segment_sizes = HasAtLeastTwoVariadic(op_def.output()); - if (need_operand_segment_sizes) { - if (ret != "") ret += ", "; - ret += "AttrSizedOperandSegments"; - } - - if (need_result_segment_sizes) { - if (ret != "") ret += ", "; - ret += "AttrSizedResultSegments"; - } - if (ret != "") ret += ", "; - ret += "DeclareOpInterfaceMethods"; - return ret; -} - -bool IsReferencedByOtherDefinitions(const std::string& op_name) { - return ShouldGenBaseClass(op_name); -} - -bool ShoudSkipOp(const std::string& op_name) { return op_name == "mlir_jit"; } - -void PrintODSFromOpRegistryResults(const std::map& results) { - for (const auto& kv : results) { - if (ShoudSkipOp(kv.first)) continue; - const ::oneflow::user_op::OpRegistryResult& r = kv.second; - auto op_class_name = GetOpClassName(kv.first); - std::cout << (ShouldGenBaseClass(r.op_type_name) ? "class" : "def") << " OneFlow_" - << op_class_name << "Op : " << GetBaseOp(r.op_type_name) << "<\"" << kv.first - << "\", [" + GetTraits(r) + "]> "; // TODO: add traits - if (ShouldGenEmptyBody(r.op_type_name)) { - std::cout << "{}\n"; - } else { - PrintBody(r); - } - std::cout << "\n"; - } -} - -void PrintNamesInResults(const std::map& results) { - std::cout << "// "; - for (auto it = results.begin(); it != results.end(); ++it) { - std::cout << it->first; - if (std::next(it) != results.end()) { std::cout << ", "; } - } - std::cout << "\n"; -} - -void PrintGroupNames(std::map>& groups) { - std::cout << "// "; - for (auto it = groups.begin(); it != groups.end(); ++it) { - if (ShoudSkipOp(it->first)) continue; - std::cout << it->first; - if (std::next(it) != groups.end()) { std::cout << ";"; } - } - std::cout << "\n\n"; -} - -void PrintIncludes(std::map>& groups) { - std::cout << "/*\n"; - for (auto it = groups.begin(); it != groups.end(); ++it) { - auto group_name = it->first; - if (group_name == "BASE") continue; - if (group_name == "TEST") continue; - std::transform(group_name.begin(), group_name.end(), group_name.begin(), ::tolower); - group_name += "_ops"; - std::cout << "#define GET_OP_LIST\n"; - std::cout << "#include \"OneFlow/OneFlow." 
<< group_name << ".cpp.inc\"\n"; - if (std::next(it) != groups.end()) { std::cout << ",\n"; } - } - std::cout << "*/\n\n"; -} - -void GroupOpRegistryResults(const std::map& results, - std::map>& groups) { - for (const auto& kv : results) { - std::string group_name = "MISC"; - const ::oneflow::user_op::OpRegistryResult& r = kv.second; - if (IsUnaryOp(r)) { group_name = "Unary"; } - if (IsBinaryOp(r)) { group_name = "Binary"; } - if (IsImageOp(r.op_type_name)) { group_name = "Image"; } - if (IsMathOp(r.op_type_name)) { group_name = "math"; } - if (IsPaddingOp(r.op_type_name)) { group_name = "PADDING"; } - if (IsIndicesOp(r.op_type_name)) { group_name = "Indices"; } - if (IsBroadcastOp(r.op_type_name)) { group_name = "Broadcast"; } - if (IsScalarOp(r.op_type_name)) { group_name = "Scalar"; } - if (IsReduceOp(r.op_type_name)) { group_name = "reduce"; } - if (IsReshapeOp(r.op_type_name)) { group_name = "reshape"; } - if (IsLossOp(r.op_type_name)) { group_name = "loss"; } - if (IsNormalizationOp(r.op_type_name)) { group_name = "Normalization"; } - if (IsCrossEntropyOp(r.op_type_name)) { group_name = "Cross_Entropy"; } - if (IsSoftmaxOp(r.op_type_name)) { group_name = "Softmax"; } - if (IsNCCLOp(r.op_type_name)) { group_name = "NCCL"; } - if (IsAnyConvOp(r.op_type_name)) { group_name = "CONV"; } - if (IsAnyPoolOp(r.op_type_name)) { group_name = "POOL"; } - if (IsUpsampleOp(r.op_type_name)) { group_name = "UPSAMPLE"; } - if (IsAssignOp(r.op_type_name)) { group_name = "assign"; } - if (IsOptimizerOp(r.op_type_name)) { group_name = "OPTIMIZER"; } - if (IsTrigonometric(r.op_type_name)) { group_name = "TRIGONOMETRIC"; } - if (IsIdempotentOp(r.op_type_name)) { group_name = "IDEMPOTENT"; } - if (IsInvolutionOp(r.op_type_name)) { group_name = "INVOLUTION"; } - if (IsIdentityOp(r.op_type_name)) { group_name = "Identity"; } - if (IsFusedOp(r.op_type_name)) { group_name = "Fused"; } - if (IsEagerOp(r.op_type_name)) { group_name = "eager"; } - if (IsQuantizationOp(r.op_type_name)) { group_name = "QUANTIZATION"; } - if (IsDatasetOp(r.op_type_name)) { group_name = "DATASET"; } - if (IsMatmulOp(r.op_type_name)) { group_name = "matmul"; } - if (IsTensorBufferOp(r.op_type_name)) { group_name = "tensor_buffer"; } - if (IsTestOp(r.op_type_name)) { group_name = "TEST"; } - if (IsDetectionOp(r.op_type_name)) { group_name = "Detection"; } - if (IsSummaryOp(r.op_type_name)) { group_name = "summary"; } - if (IsCUDAOp(r.op_type_name)) { group_name = "cuda"; } - if (IsParallelCastOp(r.op_type_name)) { group_name = "parallel_cast"; } - if (ShouldGenBaseClass(r.op_type_name)) { group_name = "BASE"; } - // if (IsUsedInPatterns(r.op_type_name)) { group_name = "used_in_patterns"; } - std::transform(group_name.begin(), group_name.end(), group_name.begin(), ::toupper); - groups[group_name].insert({kv.first, kv.second}); - } -} - -} // namespace - -int main(int argc, char* argv[]) { - std::streambuf* coutBuf = std::cout.rdbuf(); - std::ofstream of("OneFlowUserOpGen.td"); - std::streambuf* fileBuf = of.rdbuf(); - std::cout.rdbuf(fileBuf); - - std::map sorted{}; - auto unordered = ::oneflow::user_op::UserOpRegistryMgr::Get().GetAllOpRegistryResults(); - std::transform(unordered.begin(), unordered.end(), std::inserter(sorted, sorted.end()), - [](const std::pair& p) { return p; }); - std::map> groups; - GroupOpRegistryResults(sorted, groups); - PrintGroupNames(groups); - PrintIncludes(groups); - // std::cout << "#ifndef ONEFLOW_USER_OP_GEN\n"; - // std::cout << "#define ONEFLOW_USER_OP_GEN\n\n"; - - for (const auto& kv : groups) { - 
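// Each group is wrapped in #ifdef GET_ONEFLOW_<GROUP>_OP_DEFINITIONS so a .td file can pull in one group at a time. -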
auto group_name = kv.first; - auto results = kv.second; - std::cout << "// Group: " << group_name << "\n"; - PrintNamesInResults(results); - std::cout << "// " - << "Total: " << kv.second.size() << "\n\n"; - CHECK(kv.second.size()) << group_name; - auto get_group_by_name = "GET_ONEFLOW_" + group_name + "_OP_DEFINITIONS"; - auto group_def_name = "ONEFLOW_" + group_name + "_OPS"; - std::cout << "#ifdef " << get_group_by_name << "\n\n"; - // std::cout << "#ifndef " << group_def_name << "\n\n"; - // std::cout << "#define " << group_def_name << "\n\n"; - PrintODSFromOpRegistryResults(results); - // std::cout << "#endif // " << group_def_name << "\n\n"; - std::cout << "#endif // " << get_group_by_name << "\n\n"; - } - of.flush(); - of.close(); - - std::cout.rdbuf(coutBuf); - return 0; -} diff --git a/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py b/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py index 71d3c8a49ad..895407ef068 100644 --- a/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py +++ b/oneflow/ir/test/OneFlow/test_fuse_cast_scale.py @@ -80,7 +80,7 @@ def FuseCastScaleJob( test_case.assertTrue(np.allclose(loss, x * scale)) -# CHECK: %0 = oneflow.mlir_jit +# CHECK: oneflow.mlir_jit if __name__ == "__main__": unittest.main() diff --git a/oneflow/user/kernels/relu_kernel.cpp b/oneflow/user/kernels/relu_kernel.cpp index bf9124fc50b..bbc03a052a4 100644 --- a/oneflow/user/kernels/relu_kernel.cpp +++ b/oneflow/user/kernels/relu_kernel.cpp @@ -21,8 +21,8 @@ namespace oneflow { template std::unique_ptr NewReluPrimitive(Context* ctx) { - const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("in", 0); - const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("out", 0); + const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0); + const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0); return ep::primitive::NewPrimitive( ctx->device_type(), ep::primitive::UnaryOp::kRelu, src->data_type(), dst->data_type()); } @@ -38,8 +38,8 @@ class ReluKernel final : public user_op::OpKernel, public user_op::CudaGraphSupp auto primitive = NewReluPrimitive(ctx); CHECK(primitive); - const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("in", 0); - user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("out", 0); + const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0); const int64_t elem_cnt = x->shape().elem_cnt(); if (elem_cnt != 0) { diff --git a/oneflow/user/kernels/test_kernels.cpp b/oneflow/user/kernels/test_kernels.cpp index b229c4de5bf..580d729bc2e 100644 --- a/oneflow/user/kernels/test_kernels.cpp +++ b/oneflow/user/kernels/test_kernels.cpp @@ -31,8 +31,8 @@ class ReluKernel final : public user_op::OpKernel { private: void Compute(user_op::KernelComputeContext* ctx) const override { - const user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); - user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); + const user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("x", 0); + user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("y", 0); user_op::Tensor* tmp = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); CHECK_NOTNULL(tmp); NewKernelUtil::Relu(ctx->stream(), in_blob->shape().elem_cnt(), @@ -64,7 +64,7 @@ REGISTER_USER_KERNEL("ccrelu") .SetInferTmpSizeFn([](user_op::InferContext*) { return 10; }) .SetInplaceProposalFn([](const user_op::InferContext&, user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { - OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, true)); 
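+ // ccrelu's tensor args are renamed from in/out to x/y, matching the relu kernel above, so the inplace pair uses the new names.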
+ OF_RETURN_IF_ERROR(AddInplaceArgPairFn("y", 0, "x", 0, true)); return Maybe::Ok(); }); diff --git a/oneflow/user/ops/acc_op.cpp b/oneflow/user/ops/acc_op.cpp index 4c3188bb5f2..92df9df8f8e 100644 --- a/oneflow/user/ops/acc_op.cpp +++ b/oneflow/user/ops/acc_op.cpp @@ -14,56 +14,53 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe AccOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe AccOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe AccOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return AccOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe AccOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe AccOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); + const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); + DimVector time_shape_dim_vec = in_time_shape.dim_vec(); + CHECK_OR_RETURN(!time_shape_dim_vec.empty()); + if (time_shape_dim_vec.back() == max_acc_num) { + time_shape_dim_vec.pop_back(); + } else if (time_shape_dim_vec.back() % max_acc_num == 0) { + time_shape_dim_vec.back() /= max_acc_num; + } else { + const int64_t elem_cnt = in_time_shape.elem_cnt(); + time_shape_dim_vec.resize(1); + time_shape_dim_vec.back() = elem_cnt / max_acc_num; + } + *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); + return Maybe::Ok(); +} -REGISTER_USER_OP("acc") - .Input("in") - .Output("out") - .Attr("max_acc_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe { - const int32_t max_acc_num = ctx->user_op_conf().attr("max_acc_num"); - const Shape& in_time_shape = ctx->TimeShape4InputArgNameAndIndex("in", 0); - DimVector time_shape_dim_vec = in_time_shape.dim_vec(); - CHECK_OR_RETURN(!time_shape_dim_vec.empty()); - if (time_shape_dim_vec.back() == max_acc_num) { - time_shape_dim_vec.pop_back(); - } else if (time_shape_dim_vec.back() % max_acc_num == 0) { - time_shape_dim_vec.back() /= max_acc_num; - } else { - const 
int64_t elem_cnt = in_time_shape.elem_cnt(); - time_shape_dim_vec.resize(1); - time_shape_dim_vec.back() = elem_cnt / max_acc_num; - } - *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("acc").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/adaptive_pool_op.cpp b/oneflow/user/ops/adaptive_pool_op.cpp index 605453e9d45..ab2b083b6b9 100644 --- a/oneflow/user/ops/adaptive_pool_op.cpp +++ b/oneflow/user/ops/adaptive_pool_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -72,22 +73,51 @@ Maybe InferBWDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } -REGISTER_USER_OP("adaptive_avg_pool1d") - .Input("x") - .Attr>("output_size") - .Output("y") - .SetTensorDescInferFn(InferFWTensorDesc) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(InferFWDataType); - -REGISTER_USER_OP("adaptive_avg_pool1d_grad") - .Input("x") - .Input("dy") - .Attr>("output_size") - .Output("dx") - .SetTensorDescInferFn(InferBWTensorDesc) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(InferBWDataType); +} // namespace + +#define DEF_ADAPTIVE_AVG_POOL_OP(op_class_name_prefix) \ + /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferFWTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ + return FwGetSbpFn(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ + return InferFWDataType(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferBWTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return BwGetSbpFn(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##GradOp::InferDataType( \ + user_op::InferContext* ctx) { \ + return InferBWDataType(ctx); \ + } + +DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool1D) +DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool2D) +DEF_ADAPTIVE_AVG_POOL_OP(AdaptiveAvgPool3D) + +#undef DEF_ADAPTIVE_AVG_POOL_OP REGISTER_USER_OP_GRAD("adaptive_avg_pool1d") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -107,23 +137,6 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool1d") return Maybe::Ok(); }); -REGISTER_USER_OP("adaptive_avg_pool2d") - .Input("x") - .Attr>("output_size") - .Output("y") - .SetTensorDescInferFn(InferFWTensorDesc) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(InferFWDataType); - -REGISTER_USER_OP("adaptive_avg_pool2d_grad") - .Input("x") - .Input("dy") - .Attr>("output_size") - .Output("dx") - .SetTensorDescInferFn(InferBWTensorDesc) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(InferBWDataType); - REGISTER_USER_OP_GRAD("adaptive_avg_pool2d") 
.SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { const auto adaptive_avg_pool2d_grad_op_name = ctx->FwOp().op_name() + "_grad"; @@ -142,23 +155,6 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool2d") return Maybe::Ok(); }); -REGISTER_USER_OP("adaptive_avg_pool3d") - .Input("x") - .Attr>("output_size") - .Output("y") - .SetTensorDescInferFn(InferFWTensorDesc) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(InferFWDataType); - -REGISTER_USER_OP("adaptive_avg_pool3d_grad") - .Input("x") - .Input("dy") - .Attr>("output_size") - .Output("dx") - .SetTensorDescInferFn(InferBWTensorDesc) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(InferBWDataType); - REGISTER_USER_OP_GRAD("adaptive_avg_pool3d") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { const auto adaptive_avg_pool3d_grad_op_name = ctx->FwOp().op_name() + "_grad"; @@ -177,6 +173,4 @@ REGISTER_USER_OP_GRAD("adaptive_avg_pool3d") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/add_n_op.cpp b/oneflow/user/ops/add_n_op.cpp index 81f60e0f44b..d1c680f68bc 100644 --- a/oneflow/user/ops/add_n_op.cpp +++ b/oneflow/user/ops/add_n_op.cpp @@ -14,45 +14,55 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("add_n") - .InputWithMinimum("in", 2) - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& in_0 = ctx->InputTensorDesc("in", 0); - auto* out = ctx->OutputTensorDesc("out", 0); - CHECK_NOTNULL_OR_RETURN(out); - for (const auto& pair : ctx->inputs()) { - const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); - if (in_0.shape().NumAxes() > 0 && cur_in.shape().NumAxes() > 0) { - CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()); - } - } - *out->mut_shape() = in_0.shape(); - *out->mut_is_dynamic() = in_0.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) { - int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); - for (int64_t i = 0; i < num_axes; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& in_0 = ctx->InputTensorDesc("in", 0); - auto* out = ctx->OutputTensorDesc("out", 0); - CHECK_NOTNULL_OR_RETURN(out); - for (const auto& pair : ctx->inputs()) { - const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); - CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type()); - } - *out->mut_data_type() = in_0.data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe AddNOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& in_0 = ctx->InputTensorDesc("in", 0); + auto* out = ctx->OutputTensorDesc("out", 0); + CHECK_NOTNULL_OR_RETURN(out); + for (const auto& pair : ctx->inputs()) { + const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); + if (in_0.shape().NumAxes() > 0 && cur_in.shape().NumAxes() > 0) { + CHECK_EQ_OR_RETURN(in_0.shape(), cur_in.shape()); + } + } + *out->mut_shape() = in_0.shape(); + *out->mut_is_dynamic() = in_0.is_dynamic(); + return Maybe::Ok(); +} + +/*static*/ Maybe AddNOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return 
InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AddNOp::GetSbp(user_op::SbpContext* ctx) { + int64_t num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); + for (int64_t i = 0; i < num_axes; ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(user_op::OpArg("out", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AddNOp::InferDataType(user_op::InferContext* ctx) { + const auto& in_0 = ctx->InputTensorDesc("in", 0); + auto* out = ctx->OutputTensorDesc("out", 0); + CHECK_NOTNULL_OR_RETURN(out); + for (const auto& pair : ctx->inputs()) { + const auto& cur_in = ctx->InputTensorDesc(pair.first, pair.second); + CHECK_EQ_OR_RETURN(in_0.data_type(), cur_in.data_type()); + } + *out->mut_data_type() = in_0.data_type(); + return Maybe::Ok(); +} + +/*static*/ Maybe AddNOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("in") >= 2); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("add_n").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/affine_grid_op.cpp b/oneflow/user/ops/affine_grid_op.cpp index 449cfe21560..6d042aa851c 100644 --- a/oneflow/user/ops/affine_grid_op.cpp +++ b/oneflow/user/ops/affine_grid_op.cpp @@ -14,13 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool pass_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -44,89 +45,99 @@ Maybe CheckAttr(const user_op::UserOpDefWrapper& def, } // namespace -REGISTER_USER_OP("affine_grid") - .Input("theta") - .Output("grid") - .Attr("size") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); - user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0); - const Shape& size = ctx->Attr("size"); - // Only support 2D or 3D affine grid with NCHW layout - // For 2D grid: theta = { N, 2, 3 }, - // size = { N, C, H, W } - // grid = { N, H, W, 2 } - // For 3D grid: theta = { N, 3, 4 }, - // size = { N, C, D, H, W } - // grid = { N, D, H, W, 3 } - bool is_2d_grid = true; - if (theta.shape().At(1) == 2) { - CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; - CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4, when 2d affine grid"; - CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) - << "Theta and size MUST have same batch dimension"; - is_2d_grid = true; - } else if (theta.shape().At(1) == 3) { - CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; - CHECK_EQ_OR_RETURN(size.NumAxes(), 5) "Dimension of size MUST be 4, when 3d affine grid"; - CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) - << "Theta and size MUST have same batch dimension"; - is_2d_grid = false; - } else { - CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid"; - } - 
*grid->mut_is_dynamic() = theta.is_dynamic(); - Shape& grid_shape = *grid->mut_shape(); - if (is_2d_grid) { - grid_shape = {size.At(0), size.At(2), size.At(3), 2}; - } else { - grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3}; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("theta", 0), 0) - .Split(user_op::OpArg("grid", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("grid", 0) = ctx->InputDType("theta", 0); - return Maybe::Ok(); - }); +/* static */ Maybe AffineGridOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); + user_op::TensorDesc* grid = ctx->OutputTensorDesc("grid", 0); + const Shape& size = ctx->Attr("size"); + // Only support 2D or 3D affine grid with NCHW layout + // For 2D grid: theta = { N, 2, 3 }, + // size = { N, C, H, W } + // grid = { N, H, W, 2 } + // For 3D grid: theta = { N, 3, 4 }, + // size = { N, C, D, H, W } + // grid = { N, D, H, W, 3 } + bool is_2d_grid = true; + if (theta.shape().At(1) == 2) { + CHECK_EQ_OR_RETURN(theta.shape().At(2), 3) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; + CHECK_EQ_OR_RETURN(size.NumAxes(), 4) << "Dimension of size MUST be 4, when 2d affine grid"; + CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) + << "Theta and size MUST have same batch dimension"; + is_2d_grid = true; + } else if (theta.shape().At(1) == 3) { + CHECK_EQ_OR_RETURN(theta.shape().At(2), 4) << "Theta shape MUST be (N, 2, 3) or (N, 3, 4)"; + CHECK_EQ_OR_RETURN(size.NumAxes(), 5) << "Dimension of size MUST be 5, when 3d affine grid"; + CHECK_EQ_OR_RETURN(theta.shape().At(0), size.At(0)) + << "Theta and size MUST have same batch dimension"; + is_2d_grid = false; + } else { + CHECK_OR_RETURN(false) << "Theta MUST be 2D or 3D grid"; + } + *grid->mut_is_dynamic() = theta.is_dynamic(); + Shape& grid_shape = *grid->mut_shape(); + if (is_2d_grid) { + grid_shape = {size.At(0), size.At(2), size.At(3), 2}; + } else { + grid_shape = {size.At(0), size.At(2), size.At(3), size.At(4), 3}; + } + return Maybe::Ok(); +} -REGISTER_USER_OP("affine_grid_grad") - .Input("dgrid") - .Output("dtheta") - .Attr("size") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& size = ctx->Attr("size"); - - if (size.NumAxes() == 4) { - *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 2, 3}; - } else if (size.NumAxes() == 5) { - *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 3, 4}; - } else { - CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("dgrid", 0), 0) - .Split(user_op::OpArg("dtheta", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dtheta", 0) = ctx->InputDType("dgrid", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe AffineGridOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AffineGridOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("theta", 0), 0) + .Split(user_op::OpArg("grid", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AffineGridOp::CheckAttr(const user_op::UserOpDefWrapper& def, + 
const user_op::UserOpConfWrapper& conf) { + return CheckAttr_(def, conf); +} + +/* static */ Maybe AffineGridOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("grid", 0) = ctx->InputDType("theta", 0); + return Maybe::Ok(); +} + +/* static */ Maybe AffineGridGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& size = ctx->Attr("size"); + + if (size.NumAxes() == 4) { + *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 2, 3}; + } else if (size.NumAxes() == 5) { + *(ctx->OutputTensorDesc("dtheta", 0)->mut_shape()) = {size.At(0), 3, 4}; + } else { + CHECK_OR_RETURN(false) << "size MUST be 4D or 5D"; + } + return Maybe::Ok(); +} + +/*static*/ Maybe AffineGridGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AffineGridGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dgrid", 0), 0) + .Split(user_op::OpArg("dtheta", 0), 0) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AffineGridGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_(def, conf); +} + +/* static */ Maybe AffineGridGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dtheta", 0) = ctx->InputDType("dgrid", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("affine_grid") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/amp_white_identity_op.cpp b/oneflow/user/ops/amp_white_identity_op.cpp index 269b08c3ef7..46a90141d8d 100644 --- a/oneflow/user/ops/amp_white_identity_op.cpp +++ b/oneflow/user/ops/amp_white_identity_op.cpp @@ -14,35 +14,37 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe AmpWhiteIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} -REGISTER_USER_OP("amp_white_identity") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) { - const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - for (int i = 0; i < in.shape().NumAxes(); ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); +/*static*/ Maybe AmpWhiteIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AmpWhiteIdentityOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + for (int i = 0; i < in.shape().NumAxes(); ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe AmpWhiteIdentityOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("amp_white_identity") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -60,6 +62,4 @@ REGISTER_USER_OP_GRAD("amp_white_identity") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/arange_op.cpp b/oneflow/user/ops/arange_op.cpp index ae79ba73a26..1d6bb3556ca 100644 --- a/oneflow/user/ops/arange_op.cpp +++ b/oneflow/user/ops/arange_op.cpp @@ -14,60 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("arange") - .Output("out") - .Attr("integer_start") - .Attr("integer_delta") - .Attr("integer_limit") - .Attr("float_start") - .Attr("float_delta") - .Attr("float_limit") - .Attr("dtype") - .Attr>("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - DataType dtype = ctx->Attr("dtype"); - int64_t range_elem_cnt = 0; - if (IsIntegralDataType(dtype)) { - int64_t integer_delta = ctx->Attr("integer_delta"); - CHECK_NE_OR_RETURN(integer_delta, static_cast(0)) - << "RuntimeError: step must be nonzero. 
"; - int64_t integer_start = ctx->Attr("integer_start"); - int64_t integer_limit = ctx->Attr("integer_limit"); - // CHECK when limit > start, delta > 0; limit < start, delta < 0; - CHECK_GT_OR_RETURN((integer_limit - integer_start) / integer_delta, static_cast(0)) - << "RuntimeError: upper bound and larger bound inconsistent with step sign"; - range_elem_cnt = - std::ceil(static_cast(integer_limit - integer_start) / integer_delta); - } else { - double float_delta = ctx->Attr("float_delta"); - CHECK_NE_OR_RETURN(float_delta, static_cast(0.0)) - << "RuntimeError: step must be nonzero. "; - double float_start = ctx->Attr("float_start"); - double float_limit = ctx->Attr("float_limit"); - // CHECK when limit > start, delta > 0; limit < start, delta < 0; - // CHECK_GE For 0-Dim Tensor - CHECK_GE_OR_RETURN((float_limit - float_start) / float_delta, static_cast(0.0)) - << "RuntimeError: upper bound and larger bound inconsistent with step sign"; - range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); - } - *out_shape = Shape({range_elem_cnt}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); + +/* static */ Maybe ArangeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + DataType dtype = ctx->Attr("dtype"); + int64_t range_elem_cnt = 0; + if (IsIntegralDataType(dtype)) { + int64_t integer_delta = ctx->Attr("integer_delta"); + CHECK_NE_OR_RETURN(integer_delta, static_cast(0)) + << "RuntimeError: step must be nonzero. "; + int64_t integer_start = ctx->Attr("integer_start"); + int64_t integer_limit = ctx->Attr("integer_limit"); + // CHECK when limit > start, delta > 0; limit < start, delta < 0; + CHECK_GT_OR_RETURN((integer_limit - integer_start) / integer_delta, static_cast(0)) + << "RuntimeError: upper bound and larger bound inconsistent with step sign"; + range_elem_cnt = std::ceil(static_cast(integer_limit - integer_start) / integer_delta); + } else { + double float_delta = ctx->Attr("float_delta"); + CHECK_NE_OR_RETURN(float_delta, static_cast(0.0)) + << "RuntimeError: step must be nonzero. 
"; + double float_start = ctx->Attr("float_start"); + double float_limit = ctx->Attr("float_limit"); + // CHECK when limit > start, delta > 0; limit < start, delta < 0; + // CHECK_GE For 0-Dim Tensor + CHECK_GE_OR_RETURN((float_limit - float_start) / float_delta, static_cast(0.0)) + << "RuntimeError: upper bound and larger bound inconsistent with step sign"; + range_elem_cnt = std::ceil(static_cast(float_limit - float_start) / float_delta); + } + *out_shape = Shape({range_elem_cnt}); + return Maybe::Ok(); +} + +/*static*/ Maybe ArangeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArangeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ArangeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe ArangeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/arg_sort_op.cpp b/oneflow/user/ops/arg_sort_op.cpp index 1b7df445d72..7cc0b23ed45 100644 --- a/oneflow/user/ops/arg_sort_op.cpp +++ b/oneflow/user/ops/arg_sort_op.cpp @@ -14,35 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("arg_sort") - .Input("in") - .Output("out") - .Attr("direction") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // The current implementation can only do arg_sort in the last dimension and should use - // Broadcast (by default) instead of Split for that dimension - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - const std::string& direction = op_conf.attr("direction"); - CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt32; - return Maybe::Ok(); - }); +/* static */ Maybe ArgSortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ArgSortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArgSortOp::GetSbp(user_op::SbpContext* ctx) { + // The current implementation can only do arg_sort in the last dimension and should use + // Broadcast (by default) instead of Split for that dimension + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + 
+/* static */ Maybe<void> ArgSortOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                              const user_op::UserOpConfWrapper& conf) {
+  const std::string& direction = conf.attr<std::string>("direction");
+  CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING");
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> ArgSortOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = DataType::kInt32;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/arg_where_op.cpp b/oneflow/user/ops/arg_where_op.cpp
index 18f545ade3e..3ce31486a50 100644
--- a/oneflow/user/ops/arg_where_op.cpp
+++ b/oneflow/user/ops/arg_where_op.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
@@ -31,20 +32,25 @@ Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
 
 }  // namespace
 
-REGISTER_NO_GRAD_USER_OP("argwhere")
-    .Input("input")
-    .Output("output")
-    .Output("output_size")
-    .Attr<DataType>("dtype", DataType::kInt32)
-    .SetTensorDescInferFn(InferTensorDesc)
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const DataType dtype = ctx->Attr<DataType>("dtype");
-      user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0);
-      *output_desc->mut_data_type() = dtype;
-      user_op::TensorDesc* output_size_desc = ctx->OutputTensorDesc("output_size", 0);
-      *output_size_desc->mut_data_type() = dtype;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast);
+/* static */ Maybe<void> ArgwhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDesc(ctx);
+}
+
+/*static*/ Maybe<void> ArgwhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> ArgwhereOp::GetSbp(user_op::SbpContext* ctx) {
+  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);
+}
+
+/* static */ Maybe<void> ArgwhereOp::InferDataType(user_op::InferContext* ctx) {
+  const DataType dtype = ctx->Attr<DataType>("dtype");
+  user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0);
+  *output_desc->mut_data_type() = dtype;
+  user_op::TensorDesc* output_size_desc = ctx->OutputTensorDesc("output_size", 0);
+  *output_size_desc->mut_data_type() = dtype;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/argmax_op.cpp b/oneflow/user/ops/argmax_op.cpp
index e79105e8269..58c6581eb29 100644
--- a/oneflow/user/ops/argmax_op.cpp
+++ b/oneflow/user/ops/argmax_op.cpp
@@ -14,28 +14,32 @@ See the License for the specific language governing permissions and
 limitations under the License.
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("argmax") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - auto dim_vec = ctx->InputShape("in", 0).dim_vec(); - dim_vec.pop_back(); - *ctx->OutputShape("out", 0) = Shape(std::move(dim_vec)); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt64; - return Maybe::Ok(); - }); +/* static */ Maybe ArgmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + auto dim_vec = ctx->InputShape("in", 0).dim_vec(); + dim_vec.pop_back(); + *ctx->OutputShape("out", 0) = Shape(std::move(dim_vec)); + return Maybe::Ok(); +} + +/*static*/ Maybe ArgmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ArgmaxOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe ArgmaxOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kInt64; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/assign_op.cpp b/oneflow/user/ops/assign_op.cpp index f54342c2cce..c2b296dbca7 100644 --- a/oneflow/user/ops/assign_op.cpp +++ b/oneflow/user/ops/assign_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -64,7 +65,7 @@ Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgMo return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0); const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0); CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type()); @@ -77,30 +78,32 @@ Maybe InferDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_NO_GRAD_USER_OP("assign") - .Input("ref") - .Input("value") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType); +#define DEF_ASSIGN_OP(op_class_name) \ + /* static */ Maybe op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name::GetSbp(user_op::SbpContext* ctx) { \ + return GetSbpSignatures(ctx); \ + } \ + \ + /* static */ Maybe op_class_name::ModifyInputArg( \ + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \ + return InputArgModifierFn(GetInputArgModifierFn, conf); \ + } \ + \ + /* static */ Maybe op_class_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDataType_(ctx); \ + } -REGISTER_NO_GRAD_USER_OP("assign_if") - .Input("ref") - .Input("value") - .Input("condition") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType); +DEF_ASSIGN_OP(AssignUserOp) +DEF_ASSIGN_OP(AssignIfOp) +DEF_ASSIGN_OP(AssignIfNotOp) -REGISTER_NO_GRAD_USER_OP("assign_if_not") - .Input("ref") - .Input("value") - .Input("condition") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType); +#undef DEF_ASSIGN_OP } // namespace oneflow diff --git a/oneflow/user/ops/batch_gather_op.cpp b/oneflow/user/ops/batch_gather_op.cpp index 4f37bf102b0..f0581702ece 100644 --- a/oneflow/user/ops/batch_gather_op.cpp +++ b/oneflow/user/ops/batch_gather_op.cpp @@ -14,75 +14,79 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("batch_gather") - .Input("in") - .Input("indices") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size()); - FOR_RANGE(int64_t, i, 0, indices.shape().dim_vec().size() - 1) { - if (in.is_dynamic() && indices.is_dynamic() == false) { - CHECK_GE_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); - } else if (in.is_dynamic() == false && indices.is_dynamic()) { - UNIMPLEMENTED(); - } else { - CHECK_EQ_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); - } - } +/* static */ Maybe BatchGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_LE_OR_RETURN(indices.shape().dim_vec().size(), in.shape().dim_vec().size()); + FOR_RANGE(int64_t, i, 0, indices.shape().dim_vec().size() - 1) { + if (in.is_dynamic() && indices.is_dynamic() == false) { + CHECK_GE_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); + } else if (in.is_dynamic() == false && indices.is_dynamic()) { + UNIMPLEMENTED(); + } else { + CHECK_EQ_OR_RETURN(indices.shape().dim_vec().at(i), in.shape().dim_vec().at(i)); + } + } - DimVector dim_vec(in.shape().dim_vec()); - dim_vec.at(indices.shape().NumAxes() - 1) = indices.shape().dim_vec().back(); - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t indices_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); - if (indices_num_axes > 1) { - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } else { - auto err = std::make_shared(); - err->set_msg("BatchGatherOp: indices_num_axes equals " + std::to_string(indices_num_axes) - + " (should be bigger than 1)."); - err->mutable_check_failed_error(); - return err; - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = 
ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); + DimVector dim_vec(in.shape().dim_vec()); + dim_vec.at(indices.shape().NumAxes() - 1) = indices.shape().dim_vec().back(); + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe BatchGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BatchGatherOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t indices_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); + if (indices_num_axes > 1) { + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } else { + auto err = std::make_shared(); + err->set_msg("BatchGatherOp: indices_num_axes equals " + std::to_string(indices_num_axes) + + " (should be bigger than 1)."); + err->mutable_check_failed_error(); + return err; + } + return Maybe::Ok(); +} + +/* static */ Maybe BatchGatherOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BatchGatherOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("batch_gather") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/bernoulli_op.cpp b/oneflow/user/ops/bernoulli_op.cpp index 53f854e62f1..3068b83fd0c 100644 --- a/oneflow/user/ops/bernoulli_op.cpp +++ b/oneflow/user/ops/bernoulli_op.cpp @@ -14,32 +14,33 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("bernoulli") - .Input("in") - .Output("out") - .Attr("seed", -1) - .Attr("has_seed", false) - .Attr("dtype") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - *out_tensor->mut_shape() = in_tensor.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_data_type() = ctx->Attr("dtype"); - return Maybe::Ok(); - }); +/* static */ Maybe BernoulliOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + *out_tensor->mut_shape() = in_tensor.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe BernoulliOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BernoulliOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe BernoulliOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_data_type() = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/bias_add_op.cpp b/oneflow/user/ops/bias_add_op.cpp index 963e38c3116..ba0c928804f 100644 --- a/oneflow/user/ops/bias_add_op.cpp +++ b/oneflow/user/ops/bias_add_op.cpp @@ -14,48 +14,50 @@ See the License for the specific language governing permissions and limitations under the License. 
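// ---------------------------------------------------------------------------
// [Editor's note] BernoulliOp above is typical of this PR: every converted op
// gains an InferPhysicalTensorDesc that simply forwards to
// InferLogicalTensorDesc, because for these ops the per-rank (physical) shape
// rule is identical to the global (logical) one. A toy sketch of that
// delegation pattern (toy types, not OneFlow APIs):
#include <cstdint>
#include <vector>

struct ToyInferContext { std::vector<int64_t> in_shape; };

struct ToyElementwiseOp {
  static std::vector<int64_t> InferLogicalTensorDesc(const ToyInferContext& ctx) {
    return ctx.in_shape;  // elementwise: output shape equals input shape
  }
  static std::vector<int64_t> InferPhysicalTensorDesc(const ToyInferContext& ctx) {
    return InferLogicalTensorDesc(ctx);  // per-rank rule matches the global rule
  }
};
// ---------------------------------------------------------------------------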
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("bias_add") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .Split(user_op::OpArg("b", 0), 0) - .Split(user_op::OpArg("a", 0), axis) - .Split(ctx->outputs(), axis) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("a", 0); - return Maybe::Ok(); - }); +/* static */ Maybe BiasAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + *ctx->OutputShape("out", 0) = ctx->InputShape("a", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("a", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BiasAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BiasAddOp::GetSbp(user_op::SbpContext* ctx) { + const auto axis = ctx->Attr("axis"); + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .Split(user_op::OpArg("b", 0), 0) + .Split(user_op::OpArg("a", 0), axis) + .Split(ctx->outputs(), axis) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BiasAddOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("a", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("bias_add") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/binary_cross_entropy_op.cpp b/oneflow/user/ops/binary_cross_entropy_op.cpp index b4bd3b76f74..0d328657660 100644 --- a/oneflow/user/ops/binary_cross_entropy_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_op.cpp @@ -16,10 +16,13 @@ limitations under the License. 
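// ---------------------------------------------------------------------------
// [Editor's note] The CHECKs in BiasAddOp::InferLogicalTensorDesc above boil
// down to: "b" is 1-D and its length matches "a"'s extent at `axis`. A
// standalone sketch of that predicate (function name invented):
#include <cstdint>
#include <vector>

bool BiasAddShapesOk(const std::vector<int64_t>& a,
                     const std::vector<int64_t>& b, int32_t axis) {
  return b.size() == 1 && axis >= 0 && axis < static_cast<int32_t>(a.size())
         && a[axis] == b[0];
}
// e.g. a = (8, 16, 32), b = (16,), axis = 1  ->  true
// ---------------------------------------------------------------------------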
#include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { + namespace { -Maybe InferTensorDescFn(user_op::InferContext* ctx) { + +Maybe InferTensorDescFn_(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); @@ -37,7 +40,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -85,31 +88,47 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } } // namespace -REGISTER_USER_OP("binary_cross_entropy") - .Input("input") - .Input("target") - .OptionalInput("weight") - .Output("out") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GenLossForwardDefaultGetSbpFn()); - -REGISTER_USER_OP("binary_cross_entropy_grad") - .Input("input") - .Input("target") - .OptionalInput("weight") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn(GenLossBackwardDefaultGetSbpFn()); +/* static */ Maybe BinaryCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDescFn_(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossForwardDefaultGetSbpFn()(ctx); +} + +/* static */ Maybe BinaryCrossEntropyOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BinaryCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} + +/* static */ Maybe BinaryCrossEntropyGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossBackwardDefaultGetSbpFn()(ctx); +} + +/* static */ Maybe BinaryCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("binary_cross_entropy") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp index 16c0edc6094..0a124525a60 100644 --- 
a/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp +++ b/oneflow/user/ops/binary_cross_entropy_with_logits_op.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { @@ -42,7 +43,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe InferDataType_(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -101,45 +102,60 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } } // namespace -REGISTER_USER_OP("binary_cross_entropy_with_logits") - .Input("input") - .Input("target") - .OptionalInput("weight") - .OptionalInput("pos_weight") - .Output("out") - .Attr("has_pos_weight") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GenLossForwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder, - user_op::SbpContext* ctx) { - if (ctx->user_op_conf().has_input("pos_weight", 0)) { - builder.Broadcast(user_op::OpArg("pos_weight", 0)); - } - })); - -REGISTER_USER_OP("binary_cross_entropy_with_logits_grad") - .Input("input") - .Input("target") - .OptionalInput("weight") - .OptionalInput("pos_weight") - .Input("dy") - .Output("dx") - .Attr("has_pos_weight") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn(GenLossBackwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder, - user_op::SbpContext* ctx) { - if (ctx->user_op_conf().has_input("pos_weight", 0)) { - builder.Broadcast(user_op::OpArg("pos_weight", 0)); - } - })); +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferTensorDescFn(ctx); +} + +/*static*/ Maybe BinaryCrossEntropyWithLogitsOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossForwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + if (ctx->user_op_conf().has_input("pos_weight", 0)) { + builder.Broadcast(user_op::OpArg("pos_weight", 0)); + } + })(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType_(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe 
BinaryCrossEntropyWithLogitsGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossBackwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + if (ctx->user_op_conf().has_input("pos_weight", 0)) { + builder.Broadcast(user_op::OpArg("pos_weight", 0)); + } + })(ctx); +} + +/* static */ Maybe BinaryCrossEntropyWithLogitsGradOp::InferDataType( + user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("binary_cross_entropy_with_logits") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/broadcast_div_grad_op.cpp b/oneflow/user/ops/broadcast_div_grad_op.cpp index 39add2276c7..8e1c16a2b2a 100644 --- a/oneflow/user/ops/broadcast_div_grad_op.cpp +++ b/oneflow/user/ops/broadcast_div_grad_op.cpp @@ -14,59 +14,61 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("broadcast_div_grad") - .Input("y") - .Input("z") - .Input("dz") - .Output("dy") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); - const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); - CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) { - const int64_t axis_y = y_shape.NumAxes() - 1 - i; - const int64_t axis_z = z_shape.NumAxes() - 1 - i; - if (y_shape.At(axis_y) == z_shape.At(axis_z)) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), axis_y) - .Split(user_op::OpArg("z", 0), axis_z) - .Split(user_op::OpArg("dz", 0), axis_z) - .Split(user_op::OpArg("dy", 0), axis_y) - .Build(); - } else { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .Split(user_op::OpArg("z", 0), axis_z) - .Split(user_op::OpArg("dz", 0), axis_z) - .PartialSum(user_op::OpArg("dy", 0)) - .Build(); - } - } +/* static */ Maybe BroadcastDivGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastDivGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastDivGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); + const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); + CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, y_shape.NumAxes()) { + const int64_t axis_y = y_shape.NumAxes() - 1 - i; + const int64_t axis_z = z_shape.NumAxes() - 1 - i; + if (y_shape.At(axis_y) == z_shape.At(axis_z)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .PartialSum(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("y", 0), axis_y) + .Split(user_op::OpArg("z", 0), axis_z) + 
.Split(user_op::OpArg("dz", 0), axis_z) + .Split(user_op::OpArg("dy", 0), axis_y) .Build(); + } else { ctx->NewBuilder() .Broadcast(user_op::OpArg("y", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .PartialSum(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("z", 0), axis_z) + .Split(user_op::OpArg("dz", 0), axis_z) + .PartialSum(user_op::OpArg("dy", 0)) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .PartialSum(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastDivGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/broadcast_like_op.cpp b/oneflow/user/ops/broadcast_like_op.cpp index a1a54b6a407..6682f5ed2ea 100644 --- a/oneflow/user/ops/broadcast_like_op.cpp +++ b/oneflow/user/ops/broadcast_like_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -77,28 +78,32 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("like", 0); - return Maybe::Ok(); +} // namespace + +/* static */ Maybe BroadcastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc(ctx); } -} // namespace +/*static*/ Maybe BroadcastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("broadcast_like") - .Input("x") - .Input("like") - .Attr>("broadcast_axes") - .Output("y") - .SetTensorDescInferFn(InferTensorDesc) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); - CHECK_OR_RETURN(like_modifier != nullptr); - like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn(InferDataType); +/* static */ Maybe BroadcastLikeOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures(ctx); +} + +/* static */ Maybe BroadcastLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_OR_RETURN(like_modifier != nullptr); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("like", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("broadcast_like") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/broadcast_pow_grad_op.cpp b/oneflow/user/ops/broadcast_pow_grad_op.cpp index a203304c817..21fa575b03b 100644 --- 
a/oneflow/user/ops/broadcast_pow_grad_op.cpp +++ b/oneflow/user/ops/broadcast_pow_grad_op.cpp @@ -14,108 +14,110 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("broadcast_pow_x_grad") - .Input("x") - .Input("y") - .Input("z") - .Input("dz") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); - const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); - CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); - CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { - const int64_t _axis = z_shape.NumAxes() - 1 - i; - if (z_shape.At(_axis) == x_shape.At(_axis) && z_shape.At(_axis) == y_shape.At(_axis)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), _axis) - .Split(user_op::OpArg("y", 0), _axis) - .Split(user_op::OpArg("z", 0), _axis) - .Split(user_op::OpArg("dz", 0), _axis) - .Split(user_op::OpArg("dx", 0), _axis) - .Build(); - } - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .PartialSum(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("y", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("y", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .PartialSum(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); -REGISTER_USER_OP("broadcast_pow_y_grad") - .Input("x") - .Input("y") - .Input("z") - .Input("dz") - .Output("dy") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); - *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); - CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { - const int64_t _axis = z_shape.NumAxes() - 1 - i; - if (z_shape.At(_axis) == x_shape.At(_axis)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), _axis) - .Split(user_op::OpArg("z", 0), _axis) - .Split(user_op::OpArg("dz", 0), _axis) - .Split(user_op::OpArg("dy", 0), _axis) - .Build(); - } - } +/* static */ Maybe BroadcastPowXGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x", 0); + return 
Maybe::Ok(); +} + +/*static*/ Maybe BroadcastPowXGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastPowXGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); + const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); + CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); + CHECK_LE_OR_RETURN(y_shape.NumAxes(), z_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { + const int64_t _axis = z_shape.NumAxes() - 1 - i; + if (z_shape.At(_axis) == x_shape.At(_axis) && z_shape.At(_axis) == y_shape.At(_axis)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("z", 0)) - .Broadcast(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("x", 0), _axis) + .Split(user_op::OpArg("y", 0), _axis) + .Split(user_op::OpArg("z", 0), _axis) + .Split(user_op::OpArg("dz", 0), _axis) + .Split(user_op::OpArg("dx", 0), _axis) .Build(); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("y", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("y", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .PartialSum(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastPowXGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastPowYGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dy", 0) = ctx->InputShape("y", 0); + *ctx->OutputIsDynamic("dy", 0) = ctx->InputIsDynamic("y", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastPowYGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastPowYGradOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + const Shape& z_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("z", 0).shape(); + CHECK_LE_OR_RETURN(x_shape.NumAxes(), z_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, z_shape.NumAxes()) { + const int64_t _axis = z_shape.NumAxes() - 1 - i; + if (z_shape.At(_axis) == x_shape.At(_axis)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("z", 0)) - .PartialSum(user_op::OpArg("dz", 0)) - .Broadcast(user_op::OpArg("dy", 0)) + .Split(user_op::OpArg("x", 0), _axis) + .Split(user_op::OpArg("z", 0), _axis) + .Split(user_op::OpArg("dz", 0), _axis) + .Split(user_op::OpArg("dy", 0), _axis) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("z", 0)) + .Broadcast(user_op::OpArg("dz", 0)) + 
.Broadcast(user_op::OpArg("dy", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("z", 0)) + .PartialSum(user_op::OpArg("dz", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastPowYGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dy", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/buffer_op.cpp b/oneflow/user/ops/buffer_op.cpp index 3131d1b35b1..eb8abde1ee6 100644 --- a/oneflow/user/ops/buffer_op.cpp +++ b/oneflow/user/ops/buffer_op.cpp @@ -14,39 +14,35 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("identity_buffer") - .Input("in") - .Output("out") - .Attr("buffer_size") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -} // namespace +/* static */ Maybe IdentityBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe IdentityBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IdentityBufferOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe IdentityBufferOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/cast_like_op.cpp b/oneflow/user/ops/cast_like_op.cpp index a2a7face17c..c4d41a00be8 100644 --- a/oneflow/user/ops/cast_like_op.cpp +++ b/oneflow/user/ops/cast_like_op.cpp @@ -14,56 +14,60 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("cast_like") - .Input("in") - .Input("dtype_like") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); - CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); - dtype_like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); - for (int i = 0; i < in_shape.NumAxes(); ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("dtype_like", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dtype_like", 0)) - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("dtype_like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dtype_like", 0)) - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc("dtype_like", 0); - user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); - *output_tensor_desc->mut_data_type() = dtype_like_tensor_desc.data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe CastLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe CastLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastLikeOp::GetSbp(user_op::SbpContext* ctx) { + const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); + for (int i = 0; i < in_shape.NumAxes(); ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("dtype_like", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dtype_like", 0)) + .Broadcast(user_op::OpArg("in", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("dtype_like", 0)) + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dtype_like", 0)) + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CastLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* dtype_like_modifier = GetInputArgModifierFn("dtype_like", 0); + CHECK_NOTNULL_OR_RETURN(dtype_like_modifier); + dtype_like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe 
CastLikeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dtype_like_tensor_desc = ctx->InputTensorDesc("dtype_like", 0); + user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); + *output_tensor_desc->mut_data_type() = dtype_like_tensor_desc.data_type(); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/cast_op.cpp b/oneflow/user/ops/cast_op.cpp index 2ae1c246be2..545bcfeaba3 100644 --- a/oneflow/user/ops/cast_op.cpp +++ b/oneflow/user/ops/cast_op.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { -Maybe TensorDescInfer(user_op::InferContext* ctx) { +/* static */ Maybe CastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& input_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); *output_tensor_desc->mut_shape() = input_tensor_desc.shape(); @@ -26,7 +26,11 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe GetSbpSignatures(user_op::SbpContext* ctx) { +/*static*/ Maybe CastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); for (int i = 0; i < in_tensor.shape().NumAxes(); ++i) { ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); @@ -35,21 +39,13 @@ Maybe GetSbpSignatures(user_op::SbpContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +/* static */ Maybe CastOp::InferDataType(user_op::InferContext* ctx) { user_op::TensorDesc* output_tensor_desc = ctx->OutputTensorDesc("out", 0); DataType* dtype = output_tensor_desc->mut_data_type(); *dtype = ctx->Attr("dtype"); return Maybe::Ok(); } -REGISTER_USER_OP("cast") - .Input("in") - .Attr("dtype") - .Output("out") - .SetTensorDescInferFn(TensorDescInfer) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn(InferDataType); - REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { if (op.NeedGenGradTensor4OpInput("in", 0)) { @@ -67,5 +63,4 @@ REGISTER_USER_OP_GRAD("cast").SetGenBackwardOpConfFn([](const user_op::UserOpWra return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/cast_to_static_shape_op.cpp b/oneflow/user/ops/cast_to_static_shape_op.cpp index 749f4940bf7..20843124a24 100644 --- a/oneflow/user/ops/cast_to_static_shape_op.cpp +++ b/oneflow/user/ops/cast_to_static_shape_op.cpp @@ -14,38 +14,41 @@ See the License for the specific language governing permissions and limitations under the License. 
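// ---------------------------------------------------------------------------
// [Editor's note] CastLikeOp and CastOp above share one contract: the output
// keeps the input's shape and only the data type changes, taken from
// "dtype_like" or from the "dtype" attribute respectively. A toy sketch
// (types invented for illustration, not OneFlow's):
#include <cstdint>
#include <vector>

enum class ToyDType { kFloat, kInt32 };

struct ToyTensorMeta {
  std::vector<int64_t> shape;
  ToyDType dtype;
};

// Mirrors CastLikeOp: shape follows "in", dtype follows "dtype_like".
ToyTensorMeta CastLikeMeta(const ToyTensorMeta& in, const ToyTensorMeta& like) {
  return ToyTensorMeta{in.shape, like.dtype};
}
// ---------------------------------------------------------------------------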
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("cast_to_static_shape") - .Input("input") - .Output("output") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); - *output_desc->mut_shape() = input_desc.shape(); - output_desc->set_is_dynamic(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& input_desc = - ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); - FOR_RANGE(int64_t, i, 0, input_desc.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("output", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("input", 0)) - .PartialSum(user_op::OpArg("output", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CastToStaticShapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); + user_op::TensorDesc* output_desc = ctx->OutputTensorDesc("output", 0); + *output_desc->mut_shape() = input_desc.shape(); + output_desc->set_is_dynamic(false); + return Maybe::Ok(); +} + +/*static*/ Maybe CastToStaticShapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastToStaticShapeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); + FOR_RANGE(int64_t, i, 0, input_desc.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), i) + .Split(user_op::OpArg("output", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("input", 0)) + .PartialSum(user_op::OpArg("output", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CastToStaticShapeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("cast_to_static_shape") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/cast_to_tick_op.cpp b/oneflow/user/ops/cast_to_tick_op.cpp index de7ec6b9d87..1daf6241bb0 100644 --- a/oneflow/user/ops/cast_to_tick_op.cpp +++ b/oneflow/user/ops/cast_to_tick_op.cpp @@ -15,43 +15,46 @@ limitations under the License. 
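// ---------------------------------------------------------------------------
// [Editor's note] CastToStaticShapeOp::InferLogicalTensorDesc above copies
// the input's shape but pins is_dynamic to false; the op exists purely to
// strip dynamism from tensor metadata. A toy sketch (type invented):
struct ToyMeta {
  bool is_dynamic;  // shape/dtype fields omitted; they pass through unchanged
};

ToyMeta CastToStaticShapeMeta(ToyMeta in) {
  in.is_dynamic = false;
  return in;
}
// ---------------------------------------------------------------------------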
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("cast_to_tick") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -} // namespace +/* static */ Maybe CastToTickOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe CastToTickOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CastToTickOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe CastToTickOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), parallel_hierarchy.NumAxes()); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CastToTickOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/categorical_ordinal_encode_op.cpp b/oneflow/user/ops/categorical_ordinal_encode_op.cpp index cb2f0d4351f..ca2b4533826 100644 --- a/oneflow/user/ops/categorical_ordinal_encode_op.cpp +++ b/oneflow/user/ops/categorical_ordinal_encode_op.cpp @@ -14,64 +14,66 @@ See the License for the specific language governing permissions and limitations under the License. 
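// ---------------------------------------------------------------------------
// [Editor's note] CastToTickOp::InferNdSbp above keeps the input's hinted
// layout and emits one broadcast entry per axis of the parallel hierarchy for
// the output. A toy sketch where strings stand in for cfg::SbpParallel
// entries (this mirrors, not reuses, the real API):
#include <string>
#include <vector>

std::vector<std::string> TickOutputNdSbp(int hierarchy_axes) {
  // e.g. a 2-D device hierarchy yields {"B", "B"}
  return std::vector<std::string>(hierarchy_axes, "B");
}
// ---------------------------------------------------------------------------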
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("CategoricalOrdinalEncode") - .Input("table") - .Input("size") - .Input("in") - .Output("out") - .Attr("hash_precomputed") - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->parallel_ctx().parallel_num(), 1); - const Shape& table_shape = ctx->InputShape("table", 0); - CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); - const Shape& size_shape = ctx->InputShape("size", 0); - CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& table_shape = ctx->InputShape("table", 0); - CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); - const Shape& size_shape = ctx->InputShape("size", 0); - CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); - CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); - table->set_is_mutable(true); - table->set_requires_grad(false); - user_op::InputArgModifier* size = GetInputArgModifierFn("size", 0); - size->set_is_mutable(true); - size->set_requires_grad(false); - user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); - in->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - CHECK_OR_RETURN(op_conf.attr("hash_precomputed")); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& data_type = ctx->InputDType("in", 0); - CHECK_OR_RETURN(IsIndexDataType(data_type)); - CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type); - CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type); - *ctx->OutputDType("out", 0) = data_type; - return Maybe::Ok(); - }); +/* static */ Maybe CategoricalOrdinalEncodeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& table_shape = ctx->InputShape("table", 0); + CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); + const Shape& size_shape = ctx->InputShape("size", 0); + CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->parallel_ctx().parallel_num(), 1); + const Shape& table_shape = ctx->InputShape("table", 0); + CHECK_EQ_OR_RETURN(table_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(table_shape.elem_cnt() % 2, 0); + const Shape& size_shape = ctx->InputShape("size", 0); + CHECK_EQ_OR_RETURN(size_shape.NumAxes(), 1); + CHECK_EQ_OR_RETURN(size_shape.elem_cnt(), 1); + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/* static */ 
Maybe CategoricalOrdinalEncodeOp::GetSbp(user_op::SbpContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->parallel_num(), 1); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* table = GetInputArgModifierFn("table", 0); + table->set_is_mutable(true); + table->set_requires_grad(false); + user_op::InputArgModifier* size = GetInputArgModifierFn("size", 0); + size->set_is_mutable(true); + size->set_requires_grad(false); + user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0); + in->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::CheckAttr( + const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { + CHECK_OR_RETURN(conf.attr("hash_precomputed")); + return Maybe::Ok(); +} + +/* static */ Maybe CategoricalOrdinalEncodeOp::InferDataType(user_op::InferContext* ctx) { + const DataType& data_type = ctx->InputDType("in", 0); + CHECK_OR_RETURN(IsIndexDataType(data_type)); + CHECK_EQ_OR_RETURN(ctx->InputDType("table", 0), data_type); + CHECK_EQ_OR_RETURN(ctx->InputDType("size", 0), data_type); + *ctx->OutputDType("out", 0) = data_type; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/celu_op.cpp b/oneflow/user/ops/celu_op.cpp index 395c85a67f6..60d48152434 100644 --- a/oneflow/user/ops/celu_op.cpp +++ b/oneflow/user/ops/celu_op.cpp @@ -14,63 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe CeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("celu") - .Input("in") - .Output("out") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe CeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("celu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - 
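// Why CategoricalOrdinalEncodeOp::GetSbp above insists on parallel_num == 1:
// the op assigns ordinals by mutating shared (table, size) state, which is only
// sound when a single device owns that state; that is also why ModifyInputArg
// marks table and size as mutable, non-differentiable inputs. A standalone
// sketch of the encoding semantics (std::unordered_map stands in for the flat
// table tensor the kernel really uses):
#include <cstdint>
#include <unordered_map>
#include <vector>

std::vector<int64_t> OrdinalEncode(const std::vector<int64_t>& in,
                                   std::unordered_map<int64_t, int64_t>* table,
                                   int64_t* size) {
  std::vector<int64_t> out;
  out.reserve(in.size());
  for (int64_t key : in) {
    auto it = table->find(key);
    if (it == table->end()) { it = table->emplace(key, (*size)++).first; }
    out.push_back(it->second);  // each distinct key keeps a stable ordinal
  }
  return out;
}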
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CeluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CeluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe CeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CeluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CeluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("celu").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -90,6 +89,4 @@ REGISTER_USER_OP_GRAD("celu").SetBackwardOpConfGenFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/clip_by_value_op.cpp b/oneflow/user/ops/clip_by_value_op.cpp index acadfc6ca01..f216e077816 100644 --- a/oneflow/user/ops/clip_by_value_op.cpp +++ b/oneflow/user/ops/clip_by_value_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
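// For context on the celu/celu_grad pair migrated above: both ops are
// elementwise, which is why GetSbp offers a Split signature on every axis.
// A standalone sketch of the math, from the usual CELU definition (the real
// kernels live under oneflow/user/kernels):
#include <cmath>

float Celu(float x, float alpha) {
  // celu(x) = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
  return x > 0.0f ? x : alpha * (std::exp(x / alpha) - 1.0f);
}

float CeluGrad(float x, float dy, float alpha) {
  // d/dx celu(x) = 1 for x > 0, exp(x / alpha) otherwise
  return dy * (x > 0.0f ? 1.0f : std::exp(x / alpha));
}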
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -66,66 +67,45 @@ Maybe InferClipGradDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_USER_OP("clip_by_scalar") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Attr("floating_max") - .Attr("integral_max") - .Output("y") - .SetTensorDescInferFn(InferClipTensorDesc) - .SetGetSbpFn(GetClipSbpSignature) - .SetDataTypeInferFn(InferClipTensorDataType); - -REGISTER_USER_OP("clip_by_scalar_min") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Output("y") - .SetTensorDescInferFn(InferClipTensorDesc) - .SetGetSbpFn(GetClipSbpSignature) - .SetDataTypeInferFn(InferClipTensorDataType); - -REGISTER_USER_OP("clip_by_scalar_max") - .Input("x") - .Attr("floating_max") - .Attr("integral_max") - .Output("y") - .SetTensorDescInferFn(InferClipTensorDesc) - .SetGetSbpFn(GetClipSbpSignature) - .SetDataTypeInferFn(InferClipTensorDataType); - -REGISTER_USER_OP("clip_by_scalar_grad") - .Input("dy") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Attr("floating_max") - .Attr("integral_max") - .Output("dx") - .SetTensorDescInferFn(InferClipGradTensorDesc) - .SetGetSbpFn(GetClipGradSbpSignature) - .SetDataTypeInferFn(InferClipGradDataType); - -REGISTER_USER_OP("clip_by_scalar_min_grad") - .Input("dy") - .Input("x") - .Attr("floating_min") - .Attr("integral_min") - .Output("dx") - .SetTensorDescInferFn(InferClipGradTensorDesc) - .SetGetSbpFn(GetClipGradSbpSignature) - .SetDataTypeInferFn(InferClipGradDataType); - -REGISTER_USER_OP("clip_by_scalar_max_grad") - .Input("dy") - .Input("x") - .Attr("floating_max") - .Attr("integral_max") - .Output("dx") - .SetTensorDescInferFn(InferClipGradTensorDesc) - .SetGetSbpFn(GetClipGradSbpSignature) - .SetDataTypeInferFn(InferClipGradDataType); +#define DEF_CLIP_BY_VALUE_OP(op_class_name_prefix) \ + /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferClipTensorDesc(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ + return GetClipSbpSignature(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ + return InferClipTensorDataType(ctx); \ + } \ + /* static */ Maybe op_class_name_prefix##GradOp::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferClipGradTensorDesc(ctx); \ + } \ + /*static*/ Maybe op_class_name_prefix##GradOp::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_class_name_prefix##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return GetClipGradSbpSignature(ctx); \ + } \ + /* static */ Maybe op_class_name_prefix##GradOp::InferDataType( \ + user_op::InferContext* ctx) { \ + return InferClipGradDataType(ctx); \ + } + +DEF_CLIP_BY_VALUE_OP(ClipByScalar) +DEF_CLIP_BY_VALUE_OP(ClipByScalarMin) +DEF_CLIP_BY_VALUE_OP(ClipByScalarMax) + +#undef DEF_CLIP_BY_VALUE_OP REGISTER_USER_OP_GRAD("clip_by_scalar") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/coco_reader_op.cpp b/oneflow/user/ops/coco_reader_op.cpp index 6ab6f25457d..adfca1c99bf 100644 --- a/oneflow/user/ops/coco_reader_op.cpp +++ b/oneflow/user/ops/coco_reader_op.cpp 
@@ -14,135 +14,122 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("COCOReader") - .Output("image") - .Output("image_id") - .Output("image_size") - .Output("gt_bbox") - .Output("gt_label") - .Output("gt_segm") - .Output("gt_segm_index") - .Attr("session_id") - .Attr("annotation_file") - .Attr("image_dir") - .Attr("batch_size") - .Attr("shuffle_after_epoch", true) - .Attr("random_seed", -1) - .Attr("group_by_ratio", true) - .Attr("remove_images_without_annotations", true) - .Attr("stride_partition", false) - .Attr>("nd_sbp") - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_id", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_size", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_bbox", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_label", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm", 0)); - CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm_index", 0)); - - int64_t batch_size = ctx->Attr("batch_size"); - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - int64_t device_batch_size = batch_size; - if (sbp.has_split_parallel() && parallel_num > 1) { - CHECK_EQ_OR_RETURN(device_batch_size % parallel_num, 0); - device_batch_size /= parallel_num; - } - - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); - *image_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); - *image_id_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); - *image_size_desc->mut_shape() = Shape({device_batch_size, 2}); - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); - *bbox_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); - *label_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); - *segm_desc->mut_shape() = Shape({device_batch_size}); - user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); - *segm_index_desc->mut_shape() = Shape({device_batch_size}); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - int64_t batch_size = ctx->Attr("batch_size"); - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); - *image_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); - *image_id_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); - *image_size_desc->mut_shape() = Shape({batch_size, 2}); - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); - *bbox_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); - *label_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); - *segm_desc->mut_shape() = Shape({batch_size}); - user_op::TensorDesc* segm_index_desc = 
ctx->OutputTensorDesc("gt_segm_index", 0); - *segm_index_desc->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); - *image_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); - *image_id_desc->mut_data_type() = DataType::kInt64; - user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); - *image_size_desc->mut_data_type() = DataType::kInt32; - user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); - *bbox_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); - *label_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); - *segm_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); - *segm_index_desc->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_split_parallel()->set_axis(0); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); - CHECK_OR_RETURN(image_modifier != nullptr); - image_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* image_id_modifier = GetOutputArgModifierFn("image_id", 0); - CHECK_OR_RETURN(image_id_modifier != nullptr); - image_id_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* image_size_modifier = GetOutputArgModifierFn("image_size", 0); - CHECK_OR_RETURN(image_size_modifier != nullptr); - image_size_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_bbox_modifier = GetOutputArgModifierFn("gt_bbox", 0); - CHECK_OR_RETURN(gt_bbox_modifier != nullptr); - gt_bbox_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_label_modifier = GetOutputArgModifierFn("gt_label", 0); - CHECK_OR_RETURN(gt_label_modifier != nullptr); - gt_label_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_segm_modifier = GetOutputArgModifierFn("gt_segm", 0); - CHECK_OR_RETURN(gt_segm_modifier != nullptr); - gt_segm_modifier->set_header_infered_before_compute(false); - - user_op::OutputArgModifier* gt_segm_index_modifier = - GetOutputArgModifierFn("gt_segm_index", 0); - CHECK_OR_RETURN(gt_segm_index_modifier != nullptr); - gt_segm_index_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }); +/* static */ Maybe COCOReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + int64_t batch_size = ctx->Attr("batch_size"); + user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + *image_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + *image_id_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 
0); + *image_size_desc->mut_shape() = Shape({batch_size, 2}); + user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + *bbox_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + *label_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + *segm_desc->mut_shape() = Shape({batch_size}); + user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + *segm_index_desc->mut_shape() = Shape({batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_id", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("image_size", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_bbox", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_label", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm", 0)); + CHECK_OR_RETURN(sbp == ctx->SbpParallel4ArgNameAndIndex("gt_segm_index", 0)); + + int64_t batch_size = ctx->Attr("batch_size"); + int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + int64_t device_batch_size = batch_size; + if (sbp.has_split_parallel() && parallel_num > 1) { + CHECK_EQ_OR_RETURN(device_batch_size % parallel_num, 0); + device_batch_size /= parallel_num; + } + + user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + *image_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + *image_id_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + *image_size_desc->mut_shape() = Shape({device_batch_size, 2}); + user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + *bbox_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + *label_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + *segm_desc->mut_shape() = Shape({device_batch_size}); + user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + *segm_index_desc->mut_shape() = Shape({device_batch_size}); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0); + CHECK_OR_RETURN(image_modifier != nullptr); + image_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* image_id_modifier = GetOutputArgModifierFn("image_id", 0); + CHECK_OR_RETURN(image_id_modifier != nullptr); + image_id_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* image_size_modifier = GetOutputArgModifierFn("image_size", 0); + CHECK_OR_RETURN(image_size_modifier != nullptr); + image_size_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_bbox_modifier = GetOutputArgModifierFn("gt_bbox", 0); + CHECK_OR_RETURN(gt_bbox_modifier != nullptr); + 
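// Quick check of the per-rank batch arithmetic in COCOReaderOp's
// InferPhysicalTensorDesc above: with batch_size = 64 split (S(0)) across
// parallel_num = 4 readers, each rank produces 16 examples, and a
// non-divisible batch trips the CHECK. Illustrative helper, not repo code:
#include <cassert>
#include <cstdint>

int64_t DeviceBatchSize(int64_t batch_size, int64_t parallel_num, bool split) {
  if (split && parallel_num > 1) {
    assert(batch_size % parallel_num == 0);  // mirrors CHECK_EQ_OR_RETURN
    return batch_size / parallel_num;
  }
  return batch_size;  // broadcast: every rank reads the full batch
}
// DeviceBatchSize(64, 4, /*split=*/true) == 16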
gt_bbox_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_label_modifier = GetOutputArgModifierFn("gt_label", 0); + CHECK_OR_RETURN(gt_label_modifier != nullptr); + gt_label_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_segm_modifier = GetOutputArgModifierFn("gt_segm", 0); + CHECK_OR_RETURN(gt_segm_modifier != nullptr); + gt_segm_modifier->set_header_infered_before_compute(false); + + user_op::OutputArgModifier* gt_segm_index_modifier = GetOutputArgModifierFn("gt_segm_index", 0); + CHECK_OR_RETURN(gt_segm_index_modifier != nullptr); + gt_segm_index_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} + +/* static */ Maybe COCOReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_split_parallel()->set_axis(0); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe COCOReaderOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* image_desc = ctx->OutputTensorDesc("image", 0); + *image_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* image_id_desc = ctx->OutputTensorDesc("image_id", 0); + *image_id_desc->mut_data_type() = DataType::kInt64; + user_op::TensorDesc* image_size_desc = ctx->OutputTensorDesc("image_size", 0); + *image_size_desc->mut_data_type() = DataType::kInt32; + user_op::TensorDesc* bbox_desc = ctx->OutputTensorDesc("gt_bbox", 0); + *bbox_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* label_desc = ctx->OutputTensorDesc("gt_label", 0); + *label_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* segm_desc = ctx->OutputTensorDesc("gt_segm", 0); + *segm_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* segm_index_desc = ctx->OutputTensorDesc("gt_segm_index", 0); + *segm_index_desc->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/combined_margin_loss_op.cpp b/oneflow/user/ops/combined_margin_loss_op.cpp index d420cb35209..72854a53928 100644 --- a/oneflow/user/ops/combined_margin_loss_op.cpp +++ b/oneflow/user/ops/combined_margin_loss_op.cpp @@ -14,96 +14,94 @@ See the License for the specific language governing permissions and limitations under the License. 
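// COCOReader is a source op (outputs only), so InferNdSbp cannot consult input
// hints the way cast_to_tick does; InferNdSbp4SrcOp fills every axis of the
// parallel hierarchy with the supplied default, here split(0) over the batch
// dimension. A sketch of the assumed helper behavior:
Maybe<void> InferNdSbp4SrcOpSketch(cfg::NdSbp* nd_sbp, const Shape& hierarchy,
                                   const cfg::SbpParallel& default_sbp) {
  nd_sbp->clear_sbp_parallel();
  for (int32_t i = 0; i < hierarchy.NumAxes(); ++i) {
    nd_sbp->add_sbp_parallel()->CopyFrom(default_sbp);  // same default per axis
  }
  return Maybe<void>::Ok();
}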
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("combined_margin_loss") - .Input("x") - .Input("label") - .Output("y") - .Output("theta") - .Attr("m1") - .Attr("m2") - .Attr("m3") - .Attr("depth") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); - CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); - CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); - *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); - *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); - *theta->mut_is_dynamic() = x.is_dynamic(); - *theta->mut_shape() = label.shape(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); - label_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("label", 0), 0) - .Split(user_op::OpArg("y", 0), 0) - .Split(user_op::OpArg("theta", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 1) - .Broadcast(user_op::OpArg("label", 0)) - .Split(user_op::OpArg("y", 0), 1) - .PartialSum(user_op::OpArg("theta", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("theta", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe CombinedMarginLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + user_op::TensorDesc* theta = ctx->OutputTensorDesc("theta", 0); + CHECK_EQ_OR_RETURN(label.shape().At(0), x.shape().At(0)); + CHECK_GE_OR_RETURN(x.shape().NumAxes(), 2); + *ctx->OutputShape("y", 0) = ctx->InputShape("x", 0); + *ctx->IsDynamic4ArgNameAndIndex("y", 0) = ctx->InputIsDynamic("x", 0); + *theta->mut_is_dynamic() = x.is_dynamic(); + *theta->mut_shape() = label.shape(); + return Maybe::Ok(); +} -REGISTER_USER_OP("combined_margin_loss_grad") - .Input("dy") - .Input("label") - .Input("theta") - .Output("dx") - .Attr("m1") - .Attr("m2") - .Attr("m3") - .Attr("depth") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); - CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); - CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); - CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); - *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); - *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("label", 0), 
0) - .Split(user_op::OpArg("theta", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 1) - .Broadcast(user_op::OpArg("label", 0)) - .Broadcast(user_op::OpArg("theta", 0)) - .Split(user_op::OpArg("dx", 0), 1) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CombinedMarginLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CombinedMarginLossOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 0) + .Split(user_op::OpArg("label", 0), 0) + .Split(user_op::OpArg("y", 0), 0) + .Split(user_op::OpArg("theta", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 1) + .Broadcast(user_op::OpArg("label", 0)) + .Split(user_op::OpArg("y", 0), 1) + .PartialSum(user_op::OpArg("theta", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* label_arg_modifier = GetInputArgModifierFn("label", 0); + label_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("theta", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + const user_op::TensorDesc& theta = ctx->InputTensorDesc("theta", 0); + CHECK_EQ_OR_RETURN(label.shape().At(0), dy.shape().At(0)); + CHECK_EQ_OR_RETURN(label.shape().At(0), theta.shape().At(0)); + CHECK_GE_OR_RETURN(dy.shape().NumAxes(), 2); + *ctx->OutputShape("dx", 0) = ctx->InputShape("dy", 0); + *ctx->IsDynamic4ArgNameAndIndex("dx", 0) = ctx->InputIsDynamic("dy", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe CombinedMarginLossGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CombinedMarginLossGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("label", 0), 0) + .Split(user_op::OpArg("theta", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 1) + .Broadcast(user_op::OpArg("label", 0)) + .Broadcast(user_op::OpArg("theta", 0)) + .Split(user_op::OpArg("dx", 0), 1) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CombinedMarginLossGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("combined_margin_loss") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/concat_op.cpp b/oneflow/user/ops/concat_op.cpp index 9ba06a0022c..b631d4a15c8 100644 --- a/oneflow/user/ops/concat_op.cpp +++ b/oneflow/user/ops/concat_op.cpp @@ -14,12 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe InferTensorDesc(user_op::InferContext* ctx) { +Maybe GenGrapOp(const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) { + bool need_grad = false; + const int32_t in_size = op.input_size("in"); + FOR_RANGE(int32_t, i, 0, in_size) { + if (op.NeedGenGradTensor4OpInput("in", i)) { need_grad = true; } + } + if (need_grad) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + builder = builder.Op("split_like"); + FOR_RANGE(int32_t, i, 0, in_size) { builder = builder.Input("like", op.input("in", i)); } + user_op::UserOpConfWrapper grad_op = builder.Input("in", op.GetGradTensorWithOpOutput("out", 0)) + .Output("out", in_size) + .Attr("axis", op.attr("axis")) + .Build(); + + FOR_RANGE(int32_t, i, 0, in_size) { + if (op.NeedGenGradTensor4OpInput("in", i)) { + op.BindGradTensorWithOpInput(grad_op.output("out", i), "in", i); + } + } + AddOp(grad_op); + } + return Maybe::Ok(); +} + +} // namespace + +/* static */ Maybe ConcatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); const int64_t axis = ctx->Attr("axis"); CHECK_GE_OR_RETURN(axis, 0); @@ -57,7 +85,11 @@ Maybe InferTensorDesc(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe GetSbpSignature(user_op::SbpContext* ctx) { +/*static*/ Maybe ConcatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConcatOp::GetSbp(user_op::SbpContext* ctx) { const int64_t axis = ctx->Attr("axis"); const user_op::TensorDesc& first_in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); FOR_RANGE(int64_t, i, 0, first_in_desc.shape().NumAxes()) { @@ -68,32 +100,7 @@ Maybe GetSbpSignature(user_op::SbpContext* ctx) { return Maybe::Ok(); } -Maybe GenGrapOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { - bool need_grad = false; - const int32_t in_size = op.input_size("in"); - FOR_RANGE(int32_t, i, 0, in_size) { - if (op.NeedGenGradTensor4OpInput("in", i)) { need_grad = true; } - } - if (need_grad) { - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); - builder = builder.Op("split_like"); - FOR_RANGE(int32_t, i, 0, in_size) { builder = builder.Input("like", op.input("in", i)); } - user_op::UserOpConfWrapper grad_op = builder.Input("in", op.GetGradTensorWithOpOutput("out", 0)) - .Output("out", in_size) - .Attr("axis", op.attr("axis")) - .Build(); - - FOR_RANGE(int32_t, i, 0, in_size) { - if (op.NeedGenGradTensor4OpInput("in", i)) { - op.BindGradTensorWithOpInput(grad_op.output("out", i), "in", i); - } - } - AddOp(grad_op); - } - return Maybe::Ok(); -} - -Maybe InferDataType(user_op::InferContext* ctx) { +/* static */ Maybe ConcatOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& first_in_desc = ctx->InputTensorDesc("in", 0); for (const auto& in_arg_pair : ctx->inputs()) { const user_op::TensorDesc& in_desc = @@ -105,16 +112,11 @@ Maybe InferDataType(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace - -REGISTER_USER_OP("concat") - .InputWithMinimum("in", 1) - .Output("out") - .Attr("axis") - .Attr("max_dim_size") - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignature) - .SetDataTypeInferFn(InferDataType); +/*static*/ Maybe ConcatOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + 
CHECK_OR_RETURN(op_conf.input_size("in") >= 2); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("concat").SetGenBackwardOpConfFn(GenGrapOp); diff --git a/oneflow/user/ops/constant_op.cpp b/oneflow/user/ops/constant_op.cpp index d1b432f4760..d96f07e7562 100644 --- a/oneflow/user/ops/constant_op.cpp +++ b/oneflow/user/ops/constant_op.cpp @@ -14,31 +14,26 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("constant") - .Output("out") - .SetOutputBufferNum(1) - .Attr("floating_value") - .Attr("integer_value") - .Attr("is_floating_value") - .Attr("dtype") - .Attr("shape") - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe ConstantOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/* static */ Maybe ConstantOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } + +/* static */ Maybe ConstantOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe ConstantOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/conv_op.cpp b/oneflow/user/ops/conv_op.cpp index 37852e2e97a..64940f4d2da 100644 --- a/oneflow/user/ops/conv_op.cpp +++ b/oneflow/user/ops/conv_op.cpp @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -114,8 +115,8 @@ Maybe GetSbpSignatures4Conv(user_op::SbpContext* ctx) { } template -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe CheckAttr_(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool is_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -229,241 +230,239 @@ Maybe GenerateBackwardOpConf4Conv(const user_op::UserOpWrapper& op, user_o } // namespace -REGISTER_USER_OP("conv1d") - .Input("in") - .Input("weight") - .OptionalInput("bias") - .OptionalInput("bias_multiplier") // cudnn conv doesn't need this - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<1>) - .SetTensorDescInferFn(InferTensorDesc4Conv<1>) - .SetGetSbpFn(GetSbpSignatures4Conv) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv2d") - .Input("in") - .Input("weight") - .OptionalInput("bias") - .OptionalInput("bias_multiplier") // cudnn conv doesn't need this - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<2>) - .SetTensorDescInferFn(InferTensorDesc4Conv<2>) - .SetGetSbpFn(GetSbpSignatures4Conv) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv3d") - .Input("in") - .Input("weight") - .OptionalInput("bias") - .OptionalInput("bias_multiplier") // cudnn conv doesn't need this - .Output("out") - .Attr("filters") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups", 1) - .SetCheckAttrFn(CheckAttr<3>) - .SetTensorDescInferFn(InferTensorDesc4Conv<3>) - .SetGetSbpFn(GetSbpSignatures4Conv) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe Conv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Conv<1>(ctx); +} + +/*static*/ Maybe Conv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Conv1DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4Conv(ctx); +} + +/* static */ Maybe Conv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<1>(def, conf); +} + +/* static */ Maybe Conv1DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe Conv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Conv<2>(ctx); +} + +/*static*/ Maybe Conv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Conv2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4Conv(ctx); +} + +/* static */ 
Maybe Conv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<2>(def, conf); +} + +/* static */ Maybe Conv2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe Conv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Conv<3>(ctx); +} + +/*static*/ Maybe Conv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe Conv3DOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures4Conv(ctx); +} + +/* static */ Maybe Conv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<3>(def, conf); +} + +/* static */ Maybe Conv3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ConvDataGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); + const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); + CHECK_GE_OR_RETURN(num_spatial_dims, 1); + CHECK_LE_OR_RETURN(num_spatial_dims, 3); + CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); + CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2); + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); + } + *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ConvDataGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConvDataGradOp::GetSbp(user_op::SbpContext* ctx) { + std::vector split_args; + split_args.emplace_back("dy", 0); + split_args.emplace_back("x_like", 0); + split_args.emplace_back("dx", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + split_args.emplace_back("_add_to_output", 0); + } + ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg("filter", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvDataGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<0>(def, conf); +} + +/* static */ Maybe ConvDataGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); + CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type()); + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type()); + } + *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ConvFilterGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + + const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); + const int32_t groups = ctx->Attr("groups"); + const std::string& data_format = 
ctx->Attr("data_format"); + const std::vector kernel_size = ctx->Attr>("kernel_size"); + + CHECK_GE_OR_RETURN(num_spatial_dims, 1); + CHECK_LE_OR_RETURN(num_spatial_dims, 3); + CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); + CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2); + CHECK_GT_OR_RETURN(groups, 0); + + DimVector filter_diff_dim_vec; + if (data_format == "channels_first") { + CHECK_LE_OR_RETURN(groups, x.shape().At(1)); + CHECK_LE_OR_RETURN(groups, dy.shape().At(1)); + CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0); + CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0); + filter_diff_dim_vec.emplace_back(dy.shape().At(1)); + filter_diff_dim_vec.emplace_back(x.shape().At(1) / groups); + filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend()); + } else { + CHECK_EQ_OR_RETURN("channels_last", data_format); + CHECK_EQ_OR_RETURN(groups, 1); + filter_diff_dim_vec.emplace_back(dy.shape().dim_vec().back()); + filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), kernel_size.cend()); + filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups); + } + + user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); + *filter_diff->mut_shape() = Shape(filter_diff_dim_vec); + filter_diff->set_is_dynamic(false); + + return Maybe::Ok(); +} + +/*static*/ Maybe ConvFilterGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConvFilterGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x", 0), 0) + .PartialSum(user_op::OpArg("filter_diff", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvFilterGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return CheckAttr_<0>(def, conf); +} + +/* static */ Maybe ConvFilterGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type()); + user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); + *filter_diff->mut_data_type() = x.data_type(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvBiasGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); + + int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); + std::string data_format = ctx->Attr("data_format"); + + CHECK_GE_OR_RETURN(num_spatial_dims, 1); + CHECK_LE_OR_RETURN(num_spatial_dims, 3); + CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); + if (data_format == "channels_first") { + *bias_diff->mut_shape() = Shape({dy.shape().At(1)}); + } else if (data_format == "channels_last") { + *bias_diff->mut_shape() = Shape({dy.shape().At(dy.shape().NumAxes() - 1)}); + } else { + OF_UNIMPLEMENTED(); + } + return Maybe::Ok(); +} + +/*static*/ Maybe ConvBiasGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ConvBiasGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .PartialSum(user_op::OpArg("bias_diff", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ConvBiasGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, 
+ const user_op::UserOpConfWrapper& conf) { + std::string data_format = conf.attr("data_format"); + if (data_format == "channels_first" || data_format == "channels_last") { + return Maybe::Ok(); + } + return oneflow::Error::CheckFailedError() << "Illegal value for " << conf.op_type_name() << " op " + << conf.op_name() << ": data_format:" << data_format; +} + +/* static */ Maybe ConvBiasGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); + *bias_diff->mut_data_type() = dy.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("conv1d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv); REGISTER_USER_OP_GRAD("conv2d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv); REGISTER_USER_OP_GRAD("conv3d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4Conv); -REGISTER_USER_OP("conv_data_grad") - .Input("dy") - .Input("filter") - .Input("x_like") - .OptionalInput("_add_to_output") - .Output("dx") - .Attr("num_spatial_dims") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups") - .SetCheckAttrFn(CheckAttr<0>) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); - const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); - CHECK_GE_OR_RETURN(num_spatial_dims, 1); - CHECK_LE_OR_RETURN(num_spatial_dims, 3); - CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); - CHECK_EQ_OR_RETURN(x_like.shape().NumAxes(), num_spatial_dims + 2); - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), x_like.shape()); - } - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - std::vector split_args; - split_args.emplace_back("dy", 0); - split_args.emplace_back("x_like", 0); - split_args.emplace_back("dx", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - split_args.emplace_back("_add_to_output", 0); - } - ctx->NewBuilder().Split(split_args, 0).Broadcast(user_op::OpArg("filter", 0)).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x_like = ctx->InputTensorDesc("x_like", 0); - CHECK_EQ_OR_RETURN(x_like.data_type(), dy.data_type()); - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.data_type(), x_like.data_type()); - } - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv_filter_grad") - .Input("dy") - .Input("x") - .Output("filter_diff") - .Attr("num_spatial_dims") - .Attr>("padding_before") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .Attr("groups") - .SetCheckAttrFn(CheckAttr<0>) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); 
- - const int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); - const int32_t groups = ctx->Attr("groups"); - const std::string& data_format = ctx->Attr("data_format"); - const std::vector kernel_size = ctx->Attr>("kernel_size"); - - CHECK_GE_OR_RETURN(num_spatial_dims, 1); - CHECK_LE_OR_RETURN(num_spatial_dims, 3); - CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); - CHECK_EQ_OR_RETURN(x.shape().NumAxes(), num_spatial_dims + 2); - CHECK_GT_OR_RETURN(groups, 0); - - DimVector filter_diff_dim_vec; - if (data_format == "channels_first") { - CHECK_LE_OR_RETURN(groups, x.shape().At(1)); - CHECK_LE_OR_RETURN(groups, dy.shape().At(1)); - CHECK_EQ_OR_RETURN(x.shape().At(1) % groups, 0); - CHECK_EQ_OR_RETURN(dy.shape().At(1) % groups, 0); - filter_diff_dim_vec.emplace_back(dy.shape().At(1)); - filter_diff_dim_vec.emplace_back(x.shape().At(1) / groups); - filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), - kernel_size.cend()); - } else { - CHECK_EQ_OR_RETURN("channels_last", data_format); - CHECK_EQ_OR_RETURN(groups, 1); - filter_diff_dim_vec.emplace_back(dy.shape().dim_vec().back()); - filter_diff_dim_vec.insert(filter_diff_dim_vec.end(), kernel_size.cbegin(), - kernel_size.cend()); - filter_diff_dim_vec.emplace_back(x.shape().dim_vec().back() / groups); - } - - user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); - *filter_diff->mut_shape() = Shape(filter_diff_dim_vec); - filter_diff->set_is_dynamic(false); - - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .PartialSum(user_op::OpArg("filter_diff", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - CHECK_EQ_OR_RETURN(x.data_type(), dy.data_type()); - user_op::TensorDesc* filter_diff = ctx->OutputTensorDesc("filter_diff", 0); - *filter_diff->mut_data_type() = x.data_type(); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("conv_bias_grad") - .Input("dy") - .Output("bias_diff") - .Attr("data_format") - .Attr("num_spatial_dims") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - std::string data_format = conf.attr("data_format"); - if (data_format == "channels_first" || data_format == "channels_last") { - return Maybe::Ok(); - } - return oneflow::Error::CheckFailedError() - << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() - << ": data_format:" << data_format; - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); - - int32_t num_spatial_dims = ctx->Attr("num_spatial_dims"); - std::string data_format = ctx->Attr("data_format"); - - CHECK_GE_OR_RETURN(num_spatial_dims, 1); - CHECK_LE_OR_RETURN(num_spatial_dims, 3); - CHECK_EQ_OR_RETURN(dy.shape().NumAxes(), num_spatial_dims + 2); - if (data_format == "channels_first") { - *bias_diff->mut_shape() = Shape({dy.shape().At(1)}); - } else if (data_format == "channels_last") { - *bias_diff->mut_shape() = Shape({dy.shape().At(dy.shape().NumAxes() - 1)}); - } else { - OF_UNIMPLEMENTED(); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - 
.Split(user_op::OpArg("dy", 0), 0) - .PartialSum(user_op::OpArg("bias_diff", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* bias_diff = ctx->OutputTensorDesc("bias_diff", 0); - *bias_diff->mut_data_type() = dy.data_type(); - return Maybe::Ok(); - }); - } // namespace oneflow diff --git a/oneflow/user/ops/copy_op.cpp b/oneflow/user/ops/copy_op.cpp index 774cd3c184a..298a8ab7d2e 100644 --- a/oneflow/user/ops/copy_op.cpp +++ b/oneflow/user/ops/copy_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/framework/device.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -31,44 +32,41 @@ Maybe> MakeOpDevice(const Symbol& in_device, } } -std::function>(user_op::DeviceInferContext* ctx)> GetDeviceInferFn() { - std::function>(user_op::DeviceInferContext * ctx)> fn = - [](user_op::DeviceInferContext* ctx) -> Maybe> { - Symbol out_device = - JUST(Device::New(ctx->Attr("device_type"), ctx->Attr("device_id"))); - *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; - const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); - return MakeOpDevice(in_device, out_device); - }; - return fn; +} // namespace + +/* static */ Maybe CopyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); } -REGISTER_USER_OP("copy") - .Input("in") - .Output("out") - .Attr("device_type") - .Attr("device_id") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(GetDeviceInferFn()) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& inputs = ctx->inputs(); - CHECK_EQ_OR_RETURN(inputs.size(), 1); - const auto& input = - ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs[0].first, inputs[0].second); - for (int64_t axis = 0; axis < input.shape().NumAxes(); ++axis) { - ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build(); - } - ctx->NewBuilder().PartialSum(inputs).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CopyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CopyOp::GetSbp(user_op::SbpContext* ctx) { + const auto& inputs = ctx->inputs(); + CHECK_EQ_OR_RETURN(inputs.size(), 1); + const auto& input = + ctx->LogicalTensorDesc4InputArgNameAndIndex(inputs[0].first, inputs[0].second); + for (int64_t axis = 0; axis < input.shape().NumAxes(); ++axis) { + ctx->NewBuilder().Split(inputs, axis).Split(ctx->outputs(), axis).Build(); + } + ctx->NewBuilder().PartialSum(inputs).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe CopyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> CopyOp::InferDevice(user_op::DeviceInferContext* ctx) { + Symbol out_device = + JUST(Device::New(ctx->Attr("device_type"), ctx->Attr("device_id"))); 
+ *ctx->OutputTensorDevice4ArgNameAndIndex("out", 0) = out_device; + const Symbol& in_device = ctx->InputTensorDevice4ArgNameAndIndex("in", 0); + return MakeOpDevice(in_device, out_device); +} -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/count_not_finite_op.cpp b/oneflow/user/ops/count_not_finite_op.cpp index ba4ff545f3e..20e752a0b2c 100644 --- a/oneflow/user/ops/count_not_finite_op.cpp +++ b/oneflow/user/ops/count_not_finite_op.cpp @@ -14,62 +14,71 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("count_not_finite") - .Input("x") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_data_type() = DataType::kInt64; - return Maybe::Ok(); - }); +/* static */ Maybe CountNotFiniteOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + *y_desc->mut_shape() = Shape({1}); + return Maybe::Ok(); +} -REGISTER_NO_GRAD_USER_OP("multi_count_not_finite") - .InputWithMinimum("x", 1) - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes(); - for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) { - min_num_axes = std::min( - min_num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes()); - } - for (int64_t i = 0; i < min_num_axes; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& first_x_desc = ctx->InputTensorDesc("x", 0); - for (const auto& in_arg_pair : ctx->inputs()) { - const user_op::TensorDesc& x_desc = - ctx->InputTensorDesc(in_arg_pair.first, in_arg_pair.second); - CHECK_EQ_OR_RETURN(x_desc.data_type(), first_x_desc.data_type()); - } - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - *y_desc->mut_data_type() = DataType::kInt64; - return Maybe::Ok(); - }); +/*static*/ Maybe CountNotFiniteOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CountNotFiniteOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe CountNotFiniteOp::InferDataType(user_op::InferContext* 
diff --git a/oneflow/user/ops/ctc_loss_op.cpp b/oneflow/user/ops/ctc_loss_op.cpp
index feaf6631e55..b8dee1ad9cc 100644
--- a/oneflow/user/ops/ctc_loss_op.cpp
+++ b/oneflow/user/ops/ctc_loss_op.cpp
@@ -14,105 +14,126 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("ctc_loss")
-    .Input("log_probs")
-    .Input("targets")
-    .Input("input_lengths")
-    .Input("target_lengths")
-    .Output("loss")
-    .Output("alpha")  // 'alpha' is just for compute log_probs's grad, alpha's grad will be ignored
-    .Attr<int64_t>("max_target_length")
-    .Attr<int32_t>("blank")
-    .Attr<bool>("zero_infinity")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0);
-      const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0);
-      const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0);
-      const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0);
-      const int64_t batch_size = log_probs.shape().At(1);
-      const int64_t max_target_length = ctx->Attr<int64_t>("max_target_length");
-      if (targets.shape().NumAxes() == 2) {
-        CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size);
-        CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length);
-      }
-      CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size);
-      CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size);
-      CHECK_GE_OR_RETURN(ctx->Attr<int32_t>("blank"), 0);
-      CHECK_LT_OR_RETURN(ctx->Attr<int32_t>("blank"), log_probs.shape().At(2));
-
-      *ctx->OutputShape("loss", 0) = Shape({batch_size});
-      *ctx->OutputShape("alpha", 0) =
-          Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("log_probs", 0), 1)  // `log_probs` batch axis is 1
-          .Split(user_op::OpArg("targets", 0), 0)
-          .Split(user_op::OpArg("input_lengths", 0), 0)
-          .Split(user_op::OpArg("target_lengths", 0), 0)
-          .Split(user_op::OpArg("loss", 0), 0)
-          .Split(user_op::OpArg("alpha", 0), 0)
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("loss", 0) = ctx->InputDType("log_probs", 0);
-      *ctx->OutputDType("alpha", 0) = ctx->InputDType("log_probs", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> CtcLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0);
+  const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0);
+  const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0);
+  const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0);
+  const int64_t batch_size = log_probs.shape().At(1);
+  const int64_t max_target_length = ctx->Attr<int64_t>("max_target_length");
+  if (targets.shape().NumAxes() == 2) {
+    CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size);
+    CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length);
+  }
+  CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size);
+  CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size);
+  CHECK_GE_OR_RETURN(ctx->Attr<int32_t>("blank"), 0);
+  CHECK_LT_OR_RETURN(ctx->Attr<int32_t>("blank"), log_probs.shape().At(2));
 
-REGISTER_USER_OP("ctc_loss_grad")
-    .Input("grad_out")
-    .Input("log_probs")
-    .Input("targets")
-    .Input("input_lengths")
-    .Input("target_lengths")
-    .Input("loss")
-    .Input("alpha")
-    .Output("grad")
-    .Attr<int64_t>("max_target_length")
-    .Attr<int32_t>("blank")
-    .Attr<bool>("zero_infinity")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0);
-      const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0);
-      const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0);
-      const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0);
-      const int64_t batch_size = log_probs.shape().At(1);
-      const int64_t max_target_length = ctx->Attr<int64_t>("max_target_length");
-      if (targets.shape().NumAxes() == 2) {
-        CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size);
-        CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length);
-      }
-      CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size);
-      CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size);
-      CHECK_GE_OR_RETURN(ctx->Attr<int32_t>("blank"), 0);
-      CHECK_LT_OR_RETURN(ctx->Attr<int32_t>("blank"), log_probs.shape().At(2));
-
-      *ctx->OutputShape("grad", 0) = log_probs.shape();
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("grad_out", 0), 0)
-          .Split(user_op::OpArg("log_probs", 0), 1)  // `log_probs` batch axis is 1
-          .Split(user_op::OpArg("targets", 0), 0)
-          .Split(user_op::OpArg("input_lengths", 0), 0)
-          .Split(user_op::OpArg("target_lengths", 0), 0)
-          .Split(user_op::OpArg("loss", 0), 0)
-          .Split(user_op::OpArg("alpha", 0), 0)
-          .Split(user_op::OpArg("grad", 0), 1)
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("grad", 0) = ctx->InputDType("log_probs", 0);
-      return Maybe<void>::Ok();
-    });
+  *ctx->OutputShape("loss", 0) = Shape({batch_size});
+  *ctx->OutputShape("alpha", 0) =
+      Shape({batch_size, log_probs.shape().At(0), 2 * max_target_length + 1});
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> CtcLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> CtcLossOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("log_probs", 0), 1)  // `log_probs` batch axis is 1
+      .Split(user_op::OpArg("targets", 0), 0)
+      .Split(user_op::OpArg("input_lengths", 0), 0)
+      .Split(user_op::OpArg("target_lengths", 0), 0)
+      .Split(user_op::OpArg("loss", 0), 0)
+      .Split(user_op::OpArg("alpha", 0), 0)
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CtcLossOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("loss", 0) = ctx->InputDType("log_probs", 0);
+  *ctx->OutputDType("alpha", 0) = ctx->InputDType("log_probs", 0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CtcLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0);
+  const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0);
+  const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0);
+  const user_op::TensorDesc& target_lengths = ctx->InputTensorDesc("target_lengths", 0);
+  const int64_t batch_size = log_probs.shape().At(1);
+  const int64_t max_target_length = ctx->Attr<int64_t>("max_target_length");
+  if (targets.shape().NumAxes() == 2) {
+    CHECK_EQ_OR_RETURN(targets.shape().At(0), batch_size);
+    CHECK_GE_OR_RETURN(targets.shape().At(1), max_target_length);
+  }
+  CHECK_EQ_OR_RETURN(input_lengths.shape().At(0), batch_size);
+  CHECK_EQ_OR_RETURN(target_lengths.shape().At(0), batch_size);
+  CHECK_GE_OR_RETURN(ctx->Attr<int32_t>("blank"), 0);
+  CHECK_LT_OR_RETURN(ctx->Attr<int32_t>("blank"), log_probs.shape().At(2));
+
+  *ctx->OutputShape("grad", 0) = log_probs.shape();
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> CtcLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> CtcLossGradOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("grad_out", 0), 0)
+      .Split(user_op::OpArg("log_probs", 0), 1)  // `log_probs` batch axis is 1
+      .Split(user_op::OpArg("targets", 0), 0)
+      .Split(user_op::OpArg("input_lengths", 0), 0)
+      .Split(user_op::OpArg("target_lengths", 0), 0)
+      .Split(user_op::OpArg("loss", 0), 0)
+      .Split(user_op::OpArg("alpha", 0), 0)
+      .Split(user_op::OpArg("grad", 0), 1)
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CtcLossGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("grad", 0) = ctx->InputDType("log_probs", 0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CtcGreedyDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0);
+  const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0);
+  const int64_t batch_size = log_probs.shape().At(1);
+  CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0));
+  *ctx->OutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)});
+  *ctx->OutputShape("neg_sum_logits", 0) = Shape({batch_size, 1});
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> CtcGreedyDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> CtcGreedyDecoderOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("log_probs", 0), 1)  // `log_probs` batch axis is 1
+      .Split(user_op::OpArg("input_lengths", 0), 0)
+      .Split(user_op::OpArg("decoded", 0), 0)
+      .Split(user_op::OpArg("neg_sum_logits", 0), 0)
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> CtcGreedyDecoderOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("decoded", 0) = ctx->InputDType("input_lengths", 0);
+  *ctx->OutputDType("neg_sum_logits", 0) = ctx->InputDType("log_probs", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("ctc_loss")
     .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
@@ -139,34 +160,4 @@ REGISTER_USER_OP_GRAD("ctc_loss")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP("ctc_greedy_decoder")
-    .Input("log_probs")
-    .Input("input_lengths")
-    .Output("decoded")
-    .Output("neg_sum_logits")
-    .Attr<bool>("merge_repeated")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& log_probs = ctx->InputTensorDesc("log_probs", 0);
-      const user_op::TensorDesc& input_lengths = ctx->InputTensorDesc("input_lengths", 0);
-      const int64_t batch_size = log_probs.shape().At(1);
-      CHECK_EQ_OR_RETURN(batch_size, input_lengths.shape().At(0));
-      *ctx->OutputShape("decoded", 0) = Shape({batch_size, log_probs.shape().At(0)});
-      *ctx->OutputShape("neg_sum_logits", 0) = Shape({batch_size, 1});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("log_probs", 0), 1)  // `log_probs` batch axis is 1
-          .Split(user_op::OpArg("input_lengths", 0), 0)
-          .Split(user_op::OpArg("decoded", 0), 0)
-          .Split(user_op::OpArg("neg_sum_logits", 0), 0)
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("decoded", 0) = ctx->InputDType("input_lengths", 0);
-      *ctx->OutputDType("neg_sum_logits", 0) = ctx->InputDType("log_probs", 0);
-      return Maybe<void>::Ok();
-    });
-
 }  // namespace oneflow
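[Editor's note] The `alpha` workspace shape in `CtcLossOp::InferLogicalTensorDesc` follows directly from the code above: `log_probs` is laid out (T, batch, num_labels), the batch axis is 1 (as the SBP builders also note), and the CTC forward table needs `2 * max_target_length + 1` label slots per time step. A small standalone check of that arithmetic, with plain vectors as stand-ins for `oneflow::Shape`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // log_probs: T=50 time steps, batch=16, 29 labels (values are arbitrary).
  std::vector<int64_t> log_probs = {50, 16, 29};
  const int64_t max_target_length = 20;
  const int64_t batch_size = log_probs[1];  // batch axis is 1

  std::vector<int64_t> loss = {batch_size};
  std::vector<int64_t> alpha = {batch_size, log_probs[0], 2 * max_target_length + 1};

  assert(loss == (std::vector<int64_t>{16}));
  assert(alpha == (std::vector<int64_t>{16, 50, 41}));
}
```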
diff --git a/oneflow/user/ops/deconv_op.cpp b/oneflow/user/ops/deconv_op.cpp
index ec105f3a91b..657098c6e7f 100644
--- a/oneflow/user/ops/deconv_op.cpp
+++ b/oneflow/user/ops/deconv_op.cpp
@@ -15,6 +15,7 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/user/ops/nn_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
@@ -81,7 +82,7 @@ Maybe<void> InferTensorDesc4DeConv(user_op::InferContext* ctx) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> InferDataType(user_op::InferContext* ctx) {
+Maybe<void> InferDataType_(user_op::InferContext* ctx) {
   *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
   return Maybe<void>::Ok();
 }
@@ -97,8 +98,8 @@ Maybe<void> GetSbpSignatures4DeConv(user_op::SbpContext* ctx) {
 }
 
 template<size_t NDims>
-Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
-                      const user_op::UserOpConfWrapper& conf) {
+Maybe<void> CheckAttr_(const user_op::UserOpDefWrapper& def,
+                       const user_op::UserOpConfWrapper& conf) {
   bool is_checked = true;
   std::stringstream err;
   err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": ";
@@ -199,56 +200,68 @@ Maybe<void> GenerateBackwardOpConf4DeConv(const user_op::UserOpWrapper& op,
 
 }  // namespace
 
-REGISTER_USER_OP("deconv1d")
-    .Input("in")
-    .Input("weight")
-    .Output("out")
-    .Attr<int32_t>("filters")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::string>("data_format")
-    .Attr<std::vector<int32_t>>("kernel_size")
-    .Attr<std::vector<int32_t>>("output_padding")
-    .Attr<std::vector<int32_t>>("strides")
-    .Attr<std::vector<int32_t>>("dilation_rate")
-    .Attr<int32_t>("groups", 1)
-    .SetCheckAttrFn(CheckAttr<1>)
-    .SetTensorDescInferFn(InferTensorDesc4DeConv<1>)
-    .SetGetSbpFn(GetSbpSignatures4DeConv)
-    .SetDataTypeInferFn(InferDataType);
-
-REGISTER_USER_OP("deconv2d")
-    .Input("in")
-    .Input("weight")
-    .Output("out")
-    .Attr<int32_t>("filters")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::string>("data_format")
-    .Attr<std::vector<int32_t>>("kernel_size")
-    .Attr<std::vector<int32_t>>("output_padding")
-    .Attr<std::vector<int32_t>>("strides")
-    .Attr<std::vector<int32_t>>("dilation_rate")
-    .Attr<int32_t>("groups", 1)
-    .SetCheckAttrFn(CheckAttr<2>)
-    .SetTensorDescInferFn(InferTensorDesc4DeConv<2>)
-    .SetGetSbpFn(GetSbpSignatures4DeConv)
-    .SetDataTypeInferFn(InferDataType);
-
-REGISTER_USER_OP("deconv3d")
-    .Input("in")
-    .Input("weight")
-    .Output("out")
-    .Attr<int32_t>("filters")
-    .Attr<std::vector<int32_t>>("padding_before")
-    .Attr<std::string>("data_format")
-    .Attr<std::vector<int32_t>>("kernel_size")
-    .Attr<std::vector<int32_t>>("output_padding")
-    .Attr<std::vector<int32_t>>("strides")
-    .Attr<std::vector<int32_t>>("dilation_rate")
-    .Attr<int32_t>("groups", 1)
-    .SetCheckAttrFn(CheckAttr<3>)
-    .SetTensorDescInferFn(InferTensorDesc4DeConv<3>)
-    .SetDataTypeInferFn(InferDataType)
-    .SetGetSbpFn(GetSbpSignatures4DeConv);
+/* static */ Maybe<void> Deconv1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDesc4DeConv<1>(ctx);
+}
+
+/*static*/ Maybe<void> Deconv1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> Deconv1DOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetSbpSignatures4DeConv(ctx);
+}
+
+/* static */ Maybe<void> Deconv1DOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                               const user_op::UserOpConfWrapper& conf) {
+  return CheckAttr_<1>(def, conf);
+}
+
+/* static */ Maybe<void> Deconv1DOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDataType_(ctx);
+}
+
+/* static */ Maybe<void> Deconv2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDesc4DeConv<2>(ctx);
+}
+
+/*static*/ Maybe<void> Deconv2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> Deconv2DOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetSbpSignatures4DeConv(ctx);
+}
+
+/* static */ Maybe<void> Deconv2DOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                               const user_op::UserOpConfWrapper& conf) {
+  return CheckAttr_<2>(def, conf);
+}
+
+/* static */ Maybe<void> Deconv2DOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDataType_(ctx);
+}
+
+/* static */ Maybe<void> Deconv3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDesc4DeConv<3>(ctx);
+}
+
+/*static*/ Maybe<void> Deconv3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> Deconv3DOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetSbpSignatures4DeConv(ctx);
+}
+
+/* static */ Maybe<void> Deconv3DOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                               const user_op::UserOpConfWrapper& conf) {
+  return CheckAttr_<3>(def, conf);
+}
+
+/* static */ Maybe<void> Deconv3DOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDataType_(ctx);
+}
 
 REGISTER_USER_OP_GRAD("deconv1d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4DeConv);
 REGISTER_USER_OP_GRAD("deconv2d").SetGenBackwardOpConfFn(GenerateBackwardOpConf4DeConv);
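[Editor's note] The renames `InferDataType` to `InferDataType_` and `CheckAttr` to `CheckAttr_` are presumably needed because the new static members carry the same names: inside `Deconv2DOp::InferDataType`, an unqualified call to `InferDataType(ctx)` would find the member itself and recurse instead of reaching the file-local helper. A compilable miniature of the resulting pattern, with stand-in names (`Deconv2DLike` is not an OneFlow type):

```cpp
#include <cassert>
#include <cstddef>

// File-local helper, parameterized on the spatial dimensionality.
template<size_t NDims>
bool CheckAttr_(size_t kernel_rank) {
  return kernel_rank == NDims;
}

struct Deconv2DLike {
  // The static member forwards to the helper; the trailing underscore keeps
  // the unqualified call from resolving back to this member.
  static bool CheckAttr(size_t kernel_rank) { return CheckAttr_<2>(kernel_rank); }
};

int main() { assert(Deconv2DLike::CheckAttr(2) && !Deconv2DLike::CheckAttr(3)); }
```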
diff --git a/oneflow/user/ops/diag_op.cpp b/oneflow/user/ops/diag_op.cpp
index 616fe3eadf2..3ea7d0ffd1d 100644
--- a/oneflow/user/ops/diag_op.cpp
+++ b/oneflow/user/ops/diag_op.cpp
@@ -14,69 +14,73 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("diag")
-    .Input("in")
-    .Output("out")
-    .Attr<int32_t>("diagonal", 0)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      const int32_t diagonal = ctx->Attr<int32_t>("diagonal");
-      const ShapeView& in_shape = in.shape();
-      const int32_t in_dim = in_shape.NumAxes();
-      CHECK_GE_OR_RETURN(in_dim, 1);
-      CHECK_LE_OR_RETURN(in_dim, 2);
-
-      DimVector out_dim_vec = {0};
-      if (in_dim == 1) {
-        int32_t out_tensor_size = in_shape.At(0) + std::abs(diagonal);
-        out_dim_vec[0] = out_tensor_size;
-        out_dim_vec.emplace_back(out_tensor_size);
-      } else {
-        if (diagonal >= 0) {
-          out_dim_vec[0] = std::min(in_shape.At(0), in_shape.At(1) - diagonal);
-        } else {
-          out_dim_vec[0] = std::min(in_shape.At(0) + diagonal, in_shape.At(1));
-        }
-        CHECK_GT_OR_RETURN(out_dim_vec[0], 0);
-      }
-
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      out_desc->set_is_dynamic(false);
-      *out_desc->mut_shape() = Shape(out_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
-
-REGISTER_USER_OP("diag_grad")
-    .Input("dy")
-    .Input("in")
-    .Attr<int32_t>("diagonal", 0)
-    .Output("dx")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      const Shape& in_shape = in.shape();
-      user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
-      *dx_desc->mut_shape() = Shape(in_shape.dim_vec());
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> DiagOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  const int32_t diagonal = ctx->Attr<int32_t>("diagonal");
+  const ShapeView& in_shape = in.shape();
+  const int32_t in_dim = in_shape.NumAxes();
+  CHECK_GE_OR_RETURN(in_dim, 1);
+  CHECK_LE_OR_RETURN(in_dim, 2);
+
+  DimVector out_dim_vec = {0};
+  if (in_dim == 1) {
+    int32_t out_tensor_size = in_shape.At(0) + std::abs(diagonal);
+    out_dim_vec[0] = out_tensor_size;
+    out_dim_vec.emplace_back(out_tensor_size);
+  } else {
+    if (diagonal >= 0) {
+      out_dim_vec[0] = std::min(in_shape.At(0), in_shape.At(1) - diagonal);
+    } else {
+      out_dim_vec[0] = std::min(in_shape.At(0) + diagonal, in_shape.At(1));
+    }
+    CHECK_GT_OR_RETURN(out_dim_vec[0], 0);
+  }
+
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  out_desc->set_is_dynamic(false);
+  *out_desc->mut_shape() = Shape(out_dim_vec);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DiagOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DiagOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DiagOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DiagGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  const Shape& in_shape = in.shape();
+  user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
+  *dx_desc->mut_shape() = Shape(in_shape.dim_vec());
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DiagGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DiagGradOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DiagGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("diag").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx)
                                                          -> Maybe<void> {
diff --git a/oneflow/user/ops/diagonal_op.cpp b/oneflow/user/ops/diagonal_op.cpp
index c6d8e9537d6..c7bed93b172 100644
--- a/oneflow/user/ops/diagonal_op.cpp
+++ b/oneflow/user/ops/diagonal_op.cpp
@@ -14,65 +14,69 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("diagonal")
-    .Input("in")
-    .Output("out")
-    .Attr<int32_t>("offset", 0)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      const int32_t offset = ctx->Attr<int32_t>("offset");
-      const ShapeView& in_shape = in.shape();
-      const int32_t in_dim = in_shape.NumAxes();
-      CHECK_GE_OR_RETURN(in_dim, 2);
-
-      DimVector out_dim_vec = {};
-      FOR_RANGE(int32_t, index, 2, in_dim) { out_dim_vec.push_back(in_shape.At(index)); }
-      int32_t last_dim = 0;
-      if (offset >= 0) {
-        last_dim = std::min(in_shape.At(0), in_shape.At(1) - offset);
-      } else {
-        last_dim = std::min(in_shape.At(0) + offset, in_shape.At(1));
-      }
-      if (last_dim < 0) { last_dim = 0; }
-      out_dim_vec.push_back(last_dim);
-
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      out_desc->set_is_dynamic(false);
-      *out_desc->mut_shape() = Shape(out_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> DiagonalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  const int32_t offset = ctx->Attr<int32_t>("offset");
+  const ShapeView& in_shape = in.shape();
+  const int32_t in_dim = in_shape.NumAxes();
+  CHECK_GE_OR_RETURN(in_dim, 2);
 
-REGISTER_USER_OP("diagonal_grad")
-    .Input("dy")
-    .Input("in")
-    .Attr<int32_t>("offset", 0)
-    .Output("dx")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      const Shape& in_shape = in.shape();
-      user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
-      *dx_desc->mut_shape() = Shape(in_shape.dim_vec());
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+  DimVector out_dim_vec = {};
+  FOR_RANGE(int32_t, index, 2, in_dim) { out_dim_vec.push_back(in_shape.At(index)); }
+  int32_t last_dim = 0;
+  if (offset >= 0) {
+    last_dim = std::min(in_shape.At(0), in_shape.At(1) - offset);
+  } else {
+    last_dim = std::min(in_shape.At(0) + offset, in_shape.At(1));
+  }
+  if (last_dim < 0) { last_dim = 0; }
+  out_dim_vec.push_back(last_dim);
+
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  out_desc->set_is_dynamic(false);
+  *out_desc->mut_shape() = Shape(out_dim_vec);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DiagonalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DiagonalOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DiagonalOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DiagonalGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  const Shape& in_shape = in.shape();
+  user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
+  *dx_desc->mut_shape() = Shape(in_shape.dim_vec());
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DiagonalGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DiagonalGradOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DiagonalGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("diagonal")
     .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
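[Editor's note] The shape rules in `DiagOp`/`DiagonalOp::InferLogicalTensorDesc` above are easy to sanity-check in isolation. A 1-D input of length n with diagonal d embeds into an (n + |d|) x (n + |d|) matrix; for a 2-D input the extracted diagonal's length is clipped by the offset. The same arithmetic with plain integers:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdlib>

// Length of the diagonal extracted from a rows x cols matrix at the given offset,
// mirroring the 2-D branch of DiagOp::InferLogicalTensorDesc.
int64_t DiagLen2D(int64_t rows, int64_t cols, int32_t diagonal) {
  return diagonal >= 0 ? std::min(rows, cols - diagonal) : std::min(rows + diagonal, cols);
}

int main() {
  assert(4 + std::abs(-2) == 6);      // 1-D input of size 4, diagonal=-2 -> 6x6 output
  assert(DiagLen2D(3, 5, 0) == 3);    // main diagonal of a 3x5 matrix
  assert(DiagLen2D(3, 5, 4) == 1);    // far upper diagonal
  assert(DiagLen2D(3, 5, -2) == 1);   // lower diagonal
}
```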
diff --git a/oneflow/user/ops/dim_gather_op.cpp b/oneflow/user/ops/dim_gather_op.cpp
index 9c490a97985..fa7bd0815a0 100644
--- a/oneflow/user/ops/dim_gather_op.cpp
+++ b/oneflow/user/ops/dim_gather_op.cpp
@@ -15,79 +15,80 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/user/kernels/dim_gather_kernel_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
-namespace user_op {
 
-REGISTER_USER_OP("dim_gather")
-    .Input("input")
-    .Input("index")
-    .Output("output")
-    .Attr<int32_t>("dim")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const TensorDesc& in = ctx->InputTensorDesc("input", 0);
-      int64_t input_num_axes = in.shape().NumAxes();
-      CHECK_GT_OR_RETURN(input_num_axes, 0);
-      CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount);
+/* static */ Maybe<void> DimGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0);
+  int64_t input_num_axes = in.shape().NumAxes();
+  CHECK_GT_OR_RETURN(input_num_axes, 0);
+  CHECK_LE_OR_RETURN(input_num_axes, kDimGatherMaxDimCount);
 
-      const TensorDesc& index = ctx->InputTensorDesc("index", 0);
-      int64_t index_num_axes = index.shape().NumAxes();
+  const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0);
+  int64_t index_num_axes = index.shape().NumAxes();
 
-      const int32_t dim = ctx->Attr<int32_t>("dim");
-      CHECK_GE_OR_RETURN(dim, 0);
-      CHECK_LT_OR_RETURN(dim, input_num_axes);
-      CHECK_EQ_OR_RETURN(input_num_axes, index_num_axes);
+  const int32_t dim = ctx->Attr<int32_t>("dim");
+  CHECK_GE_OR_RETURN(dim, 0);
+  CHECK_LT_OR_RETURN(dim, input_num_axes);
+  CHECK_EQ_OR_RETURN(input_num_axes, index_num_axes);
 
-      CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic());
+  CHECK_EQ_OR_RETURN(in.is_dynamic(), index.is_dynamic());
 
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
-      *out->mut_shape() = index.shape();
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
+  *out->mut_shape() = index.shape();
 
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const TensorDesc& index = ctx->InputTensorDesc("index", 0);
-      CHECK_OR_RETURN(IsIndexDataType(index.data_type()));
-      const TensorDesc& in = ctx->InputTensorDesc("input", 0);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
-      *out->mut_data_type() = in.data_type();
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
-      CHECK_OR_RETURN(indices_modifier != nullptr);
-      indices_modifier->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& index_tensor =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0);
-      int64_t index_num_axes = index_tensor.shape().NumAxes();
-      const int32_t dim = ctx->Attr<int32_t>("dim");
-
-      FOR_RANGE(int64_t, i, 0, index_num_axes) {
-        if (i != dim) {
-          ctx->NewBuilder()
-              .Split(user_op::OpArg("index", 0), i)
-              .Split(user_op::OpArg("input", 0), i)
-              .Split(user_op::OpArg("output", 0), i)
-              .Build();
-        } else if (i == dim) {
-          ctx->NewBuilder()
-              .Broadcast(user_op::OpArg("input", 0))
-              .Split(user_op::OpArg("index", 0), i)
-              .Split(user_op::OpArg("output", 0), i)
-              .Build();
-        }
-      }
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DimGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DimGatherOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& index_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("index", 0);
+  int64_t index_num_axes = index_tensor.shape().NumAxes();
+  const int32_t dim = ctx->Attr<int32_t>("dim");
+
+  FOR_RANGE(int64_t, i, 0, index_num_axes) {
+    if (i != dim) {
       ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("input", 0))
-          .Broadcast(user_op::OpArg("index", 0))
-          .PartialSum(user_op::OpArg("output", 0))
+          .Split(user_op::OpArg("index", 0), i)
+          .Split(user_op::OpArg("input", 0), i)
+          .Split(user_op::OpArg("output", 0), i)
           .Build();
-      return Maybe<void>::Ok();
-    });
+    } else if (i == dim) {
+      ctx->NewBuilder()
+          .Broadcast(user_op::OpArg("input", 0))
+          .Split(user_op::OpArg("index", 0), i)
+          .Split(user_op::OpArg("output", 0), i)
+          .Build();
+    }
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("input", 0))
+      .Broadcast(user_op::OpArg("index", 0))
+      .PartialSum(user_op::OpArg("output", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DimGatherOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
+  CHECK_OR_RETURN(indices_modifier != nullptr);
+  indices_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DimGatherOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0);
+  CHECK_OR_RETURN(IsIndexDataType(index.data_type()));
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("input", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
+  *out->mut_data_type() = in.data_type();
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("dim_gather")
     .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
@@ -113,6 +114,4 @@ REGISTER_USER_OP_GRAD("dim_gather")
       return Maybe<void>::Ok();
     });
 
-}  // namespace user_op
-
 }  // namespace oneflow
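[Editor's note] `DimGatherOp::GetSbp` above enumerates one signature per index axis: split everything on axis i when i is not the gather dim, broadcast the input when it is, plus a final PartialSum signature. A standalone restatement of that enumeration as a table builder (all types here are stand-ins, not OneFlow API):

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <tuple>
#include <vector>

using Sig = std::tuple<std::string, std::string, std::string>;  // input, index, output

std::vector<Sig> EnumerateSbp(int64_t index_num_axes, int32_t dim) {
  std::vector<Sig> sigs;
  for (int64_t i = 0; i < index_num_axes; ++i) {
    std::string s = "S(" + std::to_string(i) + ")";
    if (i != dim) {
      sigs.emplace_back(s, s, s);        // split all three on the same axis
    } else {
      sigs.emplace_back("B", s, s);      // input must be broadcast on the gather dim
    }
  }
  sigs.emplace_back("P", "B", "P");      // the trailing PartialSum signature
  return sigs;
}

int main() {
  auto sigs = EnumerateSbp(2, 1);
  assert((sigs[1] == Sig{"B", "S(1)", "S(1)"}));
  assert(sigs.size() == 3);
}
```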
diff --git a/oneflow/user/ops/dim_scatter_ops.cpp b/oneflow/user/ops/dim_scatter_ops.cpp
index 06d168fd178..c6f84a91faf 100644
--- a/oneflow/user/ops/dim_scatter_ops.cpp
+++ b/oneflow/user/ops/dim_scatter_ops.cpp
@@ -17,25 +17,25 @@ limitations under the License.
 #include "oneflow/core/common/maybe.h"
 #include "oneflow/core/framework/user_op_registry.h"
 #include "oneflow/user/kernels/dim_scatter_kernel_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace user_op {
-
 namespace {
 
 Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
-  const TensorDesc* input =
+  const user_op::TensorDesc* input =
       ctx->has_input("input", 0) ? &ctx->InputTensorDesc("input", 0) : nullptr;
-  const TensorDesc& index = ctx->InputTensorDesc("index", 0);
-  const TensorDesc* like = ctx->has_input("like", 0) ? &ctx->InputTensorDesc("like", 0) : nullptr;
-  const TensorDesc& src = ctx->InputTensorDesc("src", 0);
+  const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0);
+  const user_op::TensorDesc* like =
+      ctx->has_input("like", 0) ? &ctx->InputTensorDesc("like", 0) : nullptr;
+  const user_op::TensorDesc& src = ctx->InputTensorDesc("src", 0);
 
   int32_t dim = ctx->Attr<int32_t>("dim");
 
   // check index.numaxes == src.num_axes == input/like.numaxes
   int64_t src_num_axes = src.shape().NumAxes();
   CHECK_GT_OR_RETURN(src_num_axes, 0);
-  CHECK_LE_OR_RETURN(src_num_axes, kDimGatherMaxDimCount);
+  CHECK_LE_OR_RETURN(src_num_axes, user_op::kDimGatherMaxDimCount);
   int64_t index_num_axes = index.shape().NumAxes();
   CHECK_EQ_OR_RETURN(src_num_axes, index_num_axes);
@@ -71,8 +71,8 @@ Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
 }
 
 Maybe<void> InferScalarTensorDesc(user_op::InferContext* ctx) {
-  const TensorDesc& input = ctx->InputTensorDesc("input", 0);
-  const TensorDesc& index = ctx->InputTensorDesc("index", 0);
+  const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0);
+  const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0);
 
   int32_t dim = ctx->Attr<int32_t>("dim");
 
@@ -87,12 +87,12 @@ Maybe<void> InferScalarTensorDesc(user_op::InferContext* ctx) {
     CHECK_LE_OR_RETURN(index.shape().At(i), input.shape().At(i));
   }
 
-  TensorDesc* out = ctx->OutputTensorDesc("output", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("output", 0);
   *out->mut_shape() = input.shape();
   return Maybe<void>::Ok();
 }
 
-Maybe<void> InputArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn,
+Maybe<void> InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,
                                const user_op::UserOpConfWrapper&) {
   user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
   CHECK(indices_modifier != nullptr);
@@ -101,7 +101,7 @@ Maybe<void> InputArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierF
   return Maybe<void>::Ok();
 }
 
-Maybe<void> InputScalarArgModifierFn(user_op::GetInputArgModifier GetInputArgModifierFn,
+Maybe<void> InputScalarArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn,
                                      const user_op::UserOpConfWrapper&) {
   user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("index", 0);
   CHECK(indices_modifier != nullptr);
@@ -159,10 +159,10 @@ Maybe<void> SetSbpScatter(user_op::SbpContext* ctx) {
 }
 
 Maybe<void> InferDtype(user_op::InferContext* ctx) {
-  const TensorDesc& index = ctx->InputTensorDesc("index", 0);
+  const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0);
   CHECK_OR_RETURN(IsIndexDataType(index.data_type()));
   if (ctx->has_input("input", 0)) {
-    const TensorDesc& input = ctx->InputTensorDesc("input", 0);
+    const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0);
     CHECK_EQ_OR_RETURN(ctx->InputDType("input", 0), ctx->InputDType("src", 0));
   } else {
     CHECK_EQ_OR_RETURN(ctx->InputDType("like", 0), ctx->InputDType("src", 0));
@@ -172,15 +172,15 @@ Maybe<void> InferDtype(user_op::InferContext* ctx) {
 }
 
 Maybe<void> InferScalarDtype(user_op::InferContext* ctx) {
-  const TensorDesc& index = ctx->InputTensorDesc("index", 0);
+  const user_op::TensorDesc& index = ctx->InputTensorDesc("index", 0);
   CHECK_OR_RETURN(IsIndexDataType(index.data_type()));
   *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0);
   return Maybe<void>::Ok();
 }
 
 Maybe<void> ScatterBackward(user_op::BackwardOpConfContext* ctx) {
-  const TensorDesc& src = ctx->FwOp().TensorDesc4ArgNameAndIndex("src", 0);
-  const TensorDesc& index = ctx->FwOp().TensorDesc4ArgNameAndIndex("index", 0);
+  const user_op::TensorDesc& src = ctx->FwOp().TensorDesc4ArgNameAndIndex("src", 0);
+  const user_op::TensorDesc& index = ctx->FwOp().TensorDesc4ArgNameAndIndex("index", 0);
   const int64_t ndim = src.shape().NumAxes();
 
   FOR_RANGE(int64_t, i, 0, ndim) {
@@ -221,41 +221,70 @@ Maybe<void> ScatterBackward(user_op::BackwardOpConfContext* ctx) {
 
 }  // namespace
 
-#define REGISTER_SCATTER_LIKE_OP(optypename)        \
-  REGISTER_USER_OP(optypename)                      \
-      .Input("like")                                \
-      .Input("index")                               \
-      .Input("src")                                 \
-      .Output("output")                             \
-      .Attr<int32_t>("dim")                         \
-      .SetTensorDescInferFn(InferTensorDesc)        \
-      .SetInputArgModifyFn(InputArgModifierFn)      \
-      .SetDataTypeInferFn(InferDtype)               \
-      .SetGetSbpFn(SetSbpLike)
-
-#define REGISTER_SCATTER_OP(optypename)             \
-  REGISTER_USER_OP(optypename)                      \
-      .Input("input")                               \
-      .Input("index")                               \
-      .Input("src")                                 \
-      .Output("output")                             \
-      .Attr<int32_t>("dim")                         \
-      .SetTensorDescInferFn(InferTensorDesc)        \
-      .SetInputArgModifyFn(InputArgModifierFn)      \
-      .SetDataTypeInferFn(InferDtype)               \
-      .SetGetSbpFn(SetSbpScatter)
-
-#define REGISTER_SCATTER_SCALAR_OP(optypename)         \
-  REGISTER_USER_OP(optypename)                         \
-      .Input("input")                                  \
-      .Input("index")                                  \
-      .Attr<float>("src_scalar")                       \
-      .Output("output")                                \
-      .Attr<int32_t>("dim")                            \
-      .SetTensorDescInferFn(InferScalarTensorDesc)     \
-      .SetInputArgModifyFn(InputScalarArgModifierFn)   \
-      .SetDataTypeInferFn(InferScalarDtype)            \
-      .SetGetSbpFn(SetSbpScatter)
+/* static */ Maybe<void> DimScatterAddLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDesc(ctx);
+}
+
+/*static*/ Maybe<void> DimScatterAddLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DimScatterAddLikeOp::GetSbp(user_op::SbpContext* ctx) {
+  return SetSbpLike(ctx);
+}
+
+/* static */ Maybe<void> DimScatterAddLikeOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  return InputArgModifierFn(GetInputArgModifierFn, conf);
+}
+
+/* static */ Maybe<void> DimScatterAddLikeOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDtype(ctx);
+}
+
+#define DEF_SCATTER_OP(op_class_name)                                                             \
+  /* static */ Maybe<void> op_class_name::InferLogicalTensorDesc(user_op::InferContext* ctx) {    \
+    return InferTensorDesc(ctx);                                                                  \
+  }                                                                                               \
+                                                                                                  \
+  /*static*/ Maybe<void> op_class_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) {     \
+    return InferLogicalTensorDesc(ctx);                                                           \
+  }                                                                                               \
+                                                                                                  \
+  /* static */ Maybe<void> op_class_name::GetSbp(user_op::SbpContext* ctx) {                      \
+    return SetSbpScatter(ctx);                                                                    \
+  }                                                                                               \
+                                                                                                  \
+  /* static */ Maybe<void> op_class_name::ModifyInputArg(                                         \
+      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \
+    return InputArgModifierFn(GetInputArgModifierFn, conf);                                       \
+  }                                                                                               \
+                                                                                                  \
+  /* static */ Maybe<void> op_class_name::InferDataType(user_op::InferContext* ctx) {             \
+    return InferDtype(ctx);                                                                       \
+  }
+
+#define DEF_SCATTER_SCALAR_OP(optypename)                                                         \
+  /* static */ Maybe<void> optypename::InferLogicalTensorDesc(user_op::InferContext* ctx) {       \
+    return InferScalarTensorDesc(ctx);                                                            \
+  }                                                                                               \
+                                                                                                  \
+  /*static*/ Maybe<void> optypename::InferPhysicalTensorDesc(user_op::InferContext* ctx) {        \
+    return InferLogicalTensorDesc(ctx);                                                           \
+  }                                                                                               \
+                                                                                                  \
+  /* static */ Maybe<void> optypename::GetSbp(user_op::SbpContext* ctx) {                         \
+    return SetSbpScatter(ctx);                                                                    \
+  }                                                                                               \
+                                                                                                  \
+  /* static */ Maybe<void> optypename::ModifyInputArg(                                            \
+      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { \
+    return InputScalarArgModifierFn(GetInputArgModifierFn, conf);                                 \
+  }                                                                                               \
+                                                                                                  \
+  /* static */ Maybe<void> optypename::InferDataType(user_op::InferContext* ctx) {                \
+    return InferScalarDtype(ctx);                                                                 \
+  }
 
 #define REGISTER_SCATTER_GRAD(optypename) \
   REGISTER_USER_OP_GRAD(optypename).SetBackwardOpConfGenFn(ScatterBackward);
 
@@ -279,19 +308,17 @@ Maybe<void> ScatterBackward(user_op::BackwardOpConfContext* ctx) {
         });                                                                   \
     return Maybe<void>::Ok();                                                 \
   });
+
+DEF_SCATTER_OP(DimScatterAddOp);
+DEF_SCATTER_OP(DimScatterUpdateOp);
+DEF_SCATTER_OP(DimScatterMulOp);
 
-REGISTER_SCATTER_LIKE_OP("dim_scatter_add_like");
-REGISTER_SCATTER_OP("dim_scatter_add");
-REGISTER_SCATTER_OP("dim_scatter_update");
-REGISTER_SCATTER_OP("dim_scatter_mul");
-
-REGISTER_SCATTER_SCALAR_OP("dim_scatter_update_scalar");
-REGISTER_SCATTER_SCALAR_OP("dim_scatter_add_scalar");
-REGISTER_SCATTER_SCALAR_OP("dim_scatter_mul_scalar");
+DEF_SCATTER_SCALAR_OP(DimScatterUpdateScalarOp);
+DEF_SCATTER_SCALAR_OP(DimScatterAddScalarOp);
+DEF_SCATTER_SCALAR_OP(DimScatterMulScalarOp);
 
 REGISTER_SCATTER_GRAD("dim_scatter_add");
 REGISTER_SCATTER_GRAD("dim_scatter_update");
 REGISTER_SCATTER_SCALAR_GRAD("dim_scatter_update_scalar");
 
-}  // namespace user_op
+
 }  // namespace oneflow
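[Editor's note] The `DEF_SCATTER_OP`/`DEF_SCATTER_SCALAR_OP` macros above stamp out the same five static-method bodies for every scatter variant, replacing the old registration macros. A compilable miniature of that stamping trick (names here are invented for illustration, not OneFlow identifiers):

```cpp
#include <cassert>

// Plays the role of the shared file-local helper (e.g. InferTensorDesc).
static int SharedInfer(int x) { return x + 1; }

// One macro invocation defines the same method body for each op class.
#define DEF_MINI_OP(cls)                                    \
  struct cls {                                              \
    static int Infer(int x) { return SharedInfer(x); }      \
  };

DEF_MINI_OP(ScatterAddLike)
DEF_MINI_OP(ScatterMulLike)

int main() { assert(ScatterAddLike::Infer(1) == 2 && ScatterMulLike::Infer(41) == 42); }
```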
diff --git a/oneflow/user/ops/distributions/normal_op.cpp b/oneflow/user/ops/distributions/normal_op.cpp
index 676f6428bf4..29adc2e0fbb 100644
--- a/oneflow/user/ops/distributions/normal_op.cpp
+++ b/oneflow/user/ops/distributions/normal_op.cpp
@@ -15,37 +15,36 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("normal")
-    .Output("out")
-    .SetOutputBufferNum(1)
-    .Attr<double>("mean", 0)
-    .Attr<double>("std", 1)
-    .Attr<int64_t>("seed")
-    .Attr<DataType>("dtype")
-    .Attr<Shape>("shape")
-    .Attr<std::vector<std::string>>("nd_sbp")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      Shape* out_shape = ctx->OutputShape("out", 0);
-      const Shape& shape = ctx->Attr<Shape>("shape");
-      *out_shape = shape;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      auto dtype = ctx->Attr<DataType>("dtype");
-      *ctx->OutputDType("out", 0) = dtype;
-      return Maybe<void>::Ok();
-    })
-    .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> {
-      cfg::SbpParallel default_sbp;
-      default_sbp.mutable_broadcast_parallel();
-      return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
-    });
+/* static */ Maybe<void> NormalOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  Shape* out_shape = ctx->OutputShape("out", 0);
+  const Shape& shape = ctx->Attr<Shape>("shape");
+  *out_shape = shape;
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> NormalOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> NormalOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> NormalOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
+  cfg::SbpParallel default_sbp;
+  default_sbp.mutable_broadcast_parallel();
+  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
+}
+
+/* static */ Maybe<void> NormalOp::InferDataType(user_op::InferContext* ctx) {
+  auto dtype = ctx->Attr<DataType>("dtype");
+  *ctx->OutputDType("out", 0) = dtype;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
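[Editor's note] `normal`, `uniform_int`, and `uniform` (next two files) are source ops with no inputs, so their `InferNdSbp` builds a broadcast `SbpParallel` as the fallback and hands it to `InferNdSbp4SrcOp`, which, as I read it, lets the `nd_sbp` string attribute override the fallback. A stand-in sketch of that precedence rule (this is an assumption about `InferNdSbp4SrcOp`'s behavior, not a quote of it):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical restatement: attribute wins, otherwise use the broadcast fallback.
std::vector<std::string> InferNdSbp4SrcOpLike(const std::vector<std::string>& nd_sbp_attr,
                                              const std::string& fallback) {
  return nd_sbp_attr.empty() ? std::vector<std::string>{fallback} : nd_sbp_attr;
}

int main() {
  assert(InferNdSbp4SrcOpLike({}, "B") == (std::vector<std::string>{"B"}));
  assert(InferNdSbp4SrcOpLike({"S(0)"}, "B") == (std::vector<std::string>{"S(0)"}));
}
```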
diff --git a/oneflow/user/ops/distributions/uniform_int_op.cpp b/oneflow/user/ops/distributions/uniform_int_op.cpp
index 60ee65c04f3..5c7bcfc7c57 100644
--- a/oneflow/user/ops/distributions/uniform_int_op.cpp
+++ b/oneflow/user/ops/distributions/uniform_int_op.cpp
@@ -14,41 +14,40 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("uniform_int")
-    .Output("out")
-    .SetOutputBufferNum(1)
-    .Attr<int64_t>("from", 0)
-    .Attr<int64_t>("to", 1)
-    .Attr<int64_t>("seed")
-    .Attr<DataType>("dtype")
-    .Attr<Shape>("shape")
-    .Attr<std::vector<std::string>>("nd_sbp")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      Shape* out_shape = ctx->OutputShape("out", 0);
-      const Shape& shape = ctx->Attr<Shape>("shape");
-      DimVector dim_vec;
-      if (shape.NumAxes() > 0) {
-        dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
-      }
-      *out_shape = Shape(dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      auto dtype = ctx->Attr<DataType>("dtype");
-      *ctx->OutputDType("out", 0) = dtype;
-      return Maybe<void>::Ok();
-    })
-    .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> {
-      cfg::SbpParallel default_sbp;
-      default_sbp.mutable_broadcast_parallel();
-      return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
-    });
+/* static */ Maybe<void> UniformIntOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  Shape* out_shape = ctx->OutputShape("out", 0);
+  const Shape& shape = ctx->Attr<Shape>("shape");
+  DimVector dim_vec;
+  if (shape.NumAxes() > 0) {
+    dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
+  }
+  *out_shape = Shape(dim_vec);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> UniformIntOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> UniformIntOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> UniformIntOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
+  cfg::SbpParallel default_sbp;
+  default_sbp.mutable_broadcast_parallel();
+  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
+}
+
+/* static */ Maybe<void> UniformIntOp::InferDataType(user_op::InferContext* ctx) {
+  auto dtype = ctx->Attr<DataType>("dtype");
+  *ctx->OutputDType("out", 0) = dtype;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/distributions/uniform_op.cpp b/oneflow/user/ops/distributions/uniform_op.cpp
index 366c294780e..0e972755055 100644
--- a/oneflow/user/ops/distributions/uniform_op.cpp
+++ b/oneflow/user/ops/distributions/uniform_op.cpp
@@ -14,41 +14,40 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("uniform")
-    .Output("out")
-    .SetOutputBufferNum(1)
-    .Attr<double>("from", 0)
-    .Attr<double>("to", 1)
-    .Attr<int64_t>("seed")
-    .Attr<DataType>("dtype")
-    .Attr<Shape>("shape")
-    .Attr<std::vector<std::string>>("nd_sbp")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      Shape* out_shape = ctx->OutputShape("out", 0);
-      const Shape& shape = ctx->Attr<Shape>("shape");
-      DimVector dim_vec;
-      if (shape.NumAxes() > 0) {
-        dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
-      }
-      *out_shape = Shape(dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      auto dtype = ctx->Attr<DataType>("dtype");
-      *ctx->OutputDType("out", 0) = dtype;
-      return Maybe<void>::Ok();
-    })
-    .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> {
-      cfg::SbpParallel default_sbp;
-      default_sbp.mutable_broadcast_parallel();
-      return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
-    });
+/* static */ Maybe<void> UniformOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  Shape* out_shape = ctx->OutputShape("out", 0);
+  const Shape& shape = ctx->Attr<Shape>("shape");
+  DimVector dim_vec;
+  if (shape.NumAxes() > 0) {
+    dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
+  }
+  *out_shape = Shape(dim_vec);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> UniformOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> UniformOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> UniformOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
+  cfg::SbpParallel default_sbp;
+  default_sbp.mutable_broadcast_parallel();
+  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
+}
+
+/* static */ Maybe<void> UniformOp::InferDataType(user_op::InferContext* ctx) {
+  auto dtype = ctx->Attr<DataType>("dtype");
+  *ctx->OutputDType("out", 0) = dtype;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
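[Editor's note] `UniformOp::InferLogicalTensorDesc` (and the identical logic in `uniform_int`) copies the `shape` attribute through a `DimVector` so that a zero-axis attribute still produces a legal rank-0 (scalar) shape. The same logic with `std::vector` standing in for `DimVector`/`Shape`:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> OutShape(const std::vector<int64_t>& shape_attr) {
  std::vector<int64_t> dim_vec;
  if (!shape_attr.empty()) {
    dim_vec.insert(dim_vec.end(), shape_attr.begin(), shape_attr.end());
  }
  return dim_vec;  // empty vector == rank-0 (scalar) shape
}

int main() {
  assert(OutShape({2, 3}) == (std::vector<int64_t>{2, 3}));
  assert(OutShape({}).empty());
}
```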
diff --git a/oneflow/user/ops/dot_op.cpp b/oneflow/user/ops/dot_op.cpp
index 369c510fe2e..a67361aef30 100644
--- a/oneflow/user/ops/dot_op.cpp
+++ b/oneflow/user/ops/dot_op.cpp
@@ -14,39 +14,40 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
+/* static */ Maybe<void> DotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
+  CHECK_OR_RETURN(x.shape() == y.shape()) << "Input tensor shape is different";
+  CHECK_OR_RETURN(x.shape().NumAxes() == 1) << "Input tensor is not 1D";
+  *ctx->OutputShape("out", 0) = Shape({});
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("dot")
-    .Input("x")
-    .Input("y")
-    .Output("out")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
-      const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
-      CHECK_OR_RETURN(x.shape() == y.shape()) << "Input tensor shape is different";
-      CHECK_OR_RETURN(x.shape().NumAxes() == 1) << "Input tensor is not 1D";
-      *ctx->OutputShape("out", 0) = Shape({});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("x", 0), 0)
-          .Split(user_op::OpArg("y", 0), 0)
-          .PartialSum(user_op::OpArg("out", 0))
-          .Build();
+/*static*/ Maybe<void> DotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
 
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
-      const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
-      CHECK_OR_RETURN(x.data_type() == y.data_type()) << "The input tensor type is different";
-      *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> DotOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("x", 0), 0)
+      .Split(user_op::OpArg("y", 0), 0)
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DotOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
+  CHECK_OR_RETURN(x.data_type() == y.data_type()) << "The input tensor type is different";
+  *ctx->OutputDType("out", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("dot").SetGenBackwardOpConfFn(
     [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe<void> {
@@ -76,6 +77,4 @@ REGISTER_USER_OP_GRAD("dot").SetGenBackwardOpConfFn(
       return Maybe<void>::Ok();
     });
 
-}  // namespace
-
 }  // namespace oneflow
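[Editor's note] `DotOp::GetSbp` pairs `Split(x, 0)`/`Split(y, 0)` with `PartialSum(out)` because a dot product over split vectors is exactly the sum of per-shard dot products. A runnable check of that identity, simulating two "ranks" in plain C++:

```cpp
#include <cassert>
#include <numeric>
#include <vector>

double Dot(const std::vector<double>& a, const std::vector<double>& b) {
  return std::inner_product(a.begin(), a.end(), b.begin(), 0.0);
}

int main() {
  std::vector<double> x = {1, 2, 3, 4}, y = {5, 6, 7, 8};
  // Shard axis 0 across two ranks; each rank holds a partial value.
  double partial0 = Dot({1, 2}, {5, 6});
  double partial1 = Dot({3, 4}, {7, 8});
  // Summing the partials recovers the full dot product.
  assert(partial0 + partial1 == Dot(x, y));
}
```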
diff --git a/oneflow/user/ops/dropout_op.cpp b/oneflow/user/ops/dropout_op.cpp
index 41cab7bd88e..20beb57083a 100644
--- a/oneflow/user/ops/dropout_op.cpp
+++ b/oneflow/user/ops/dropout_op.cpp
@@ -14,77 +14,112 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
-
-REGISTER_USER_OP("dropout")
-    .Input("in")
-    .OptionalInput("_add_to_output")
-    .Output("out")
-    .Output("mask")
-    .Attr<float>("rate")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& in_shape = ctx->InputShape("in", 0);
-      *ctx->OutputShape("out", 0) = in_shape;
-      *ctx->OutputShape("mask", 0) = in_shape;
-      *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
-        ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      float rate = op_conf.attr<float>("rate");
-      CHECK_GE_OR_RETURN(rate, 0.0);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      *ctx->OutputDType("mask", 0) = DataType::kInt8;
-      return Maybe<void>::Ok();
-    });
-
-REGISTER_USER_OP("dropout_grad")
-    .Input("dy")
-    .Input("mask")
-    .Output("dx")
-    .Attr<float>("scale")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      *ctx->OutputShape("dx", 0) = dy_shape;
-      *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("dy", 0);
-      CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
-      FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("dy", 0), axis)
-            .Split(user_op::OpArg("mask", 0), axis)
-            .Split(user_op::OpArg("dx", 0), axis)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      float scale = op_conf.attr<float>("scale");
-      CHECK_GT_OR_RETURN(scale, 1);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kInt8);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> DropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& in_shape = ctx->InputShape("in", 0);
+  *ctx->OutputShape("out", 0) = in_shape;
+  *ctx->OutputShape("mask", 0) = in_shape;
+  *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DropoutOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes()) {
+    ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), axis).Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DropoutOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                              const user_op::UserOpConfWrapper& conf) {
+  float rate = conf.attr<float>("rate");
+  CHECK_GE_OR_RETURN(rate, 0.0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DropoutOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  *ctx->OutputDType("mask", 0) = DataType::kInt8;
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DropoutGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  *ctx->OutputShape("dx", 0) = dy_shape;
+  *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("dy", 0);
+  CHECK_EQ_OR_RETURN(ctx->InputShape("mask", 0), dy_shape);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> DropoutGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> DropoutGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
+  FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("dy", 0), axis)
+        .Split(user_op::OpArg("mask", 0), axis)
+        .Split(user_op::OpArg("dx", 0), axis)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DropoutGradOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                                  const user_op::UserOpConfWrapper& conf) {
+  float scale = conf.attr<float>("scale");
+  CHECK_GT_OR_RETURN(scale, 1);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> DropoutGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  CHECK_EQ_OR_RETURN(ctx->InputDType("mask", 0), DataType::kInt8);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> RandomMaskLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> RandomMaskLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> RandomMaskLikeOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0);
+  FOR_RANGE(int64_t, axis, 0, like_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("like", 0), axis)
+        .Split(user_op::OpArg("out", 0), axis)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> RandomMaskLikeOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                                     const user_op::UserOpConfWrapper& conf) {
+  float rate = conf.attr<float>("rate");
+  CHECK_GE_OR_RETURN(rate, 0);
+  CHECK_LT_OR_RETURN(rate, 1);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> RandomMaskLikeOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = DataType::kInt8;
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                                            user_op::AddOpFn AddOp) -> Maybe<void> {
@@ -106,38 +141,4 @@ REGISTER_USER_OP_GRAD("dropout").SetGenBackwardOpConfFn([](const user_op::UserOp
       return Maybe<void>::Ok();
     });
 
-REGISTER_NO_GRAD_USER_OP("random_mask_like")
-    .Input("like")
-    .Output("out")
-    .Attr<float>("rate")
-    .Attr<int64_t>("seed")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& like_tensor =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0);
-      FOR_RANGE(int64_t, axis, 0, like_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("like", 0), axis)
-            .Split(user_op::OpArg("out", 0), axis)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      float rate = op_conf.attr<float>("rate");
-      CHECK_GE_OR_RETURN(rate, 0);
-      CHECK_LT_OR_RETURN(rate, 1);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = DataType::kInt8;
-      return Maybe<void>::Ok();
-    });
-
-}  // namespace
-
 }  // namespace oneflow
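[Editor's note] The attribute checks in this file fit the usual inverted-dropout convention: kept activations are scaled by 1 / (1 - rate), so `rate` in [0, 1) (checked by `RandomMaskLikeOp::CheckAttr`) implies `scale` > 1 (checked by `DropoutGradOp::CheckAttr`). The 1 / (1 - rate) relation is the standard convention rather than something this diff states; a quick numeric check:

```cpp
#include <cassert>

int main() {
  float rate = 0.5f;
  float scale = 1.0f / (1.0f - rate);   // inverted-dropout scaling (assumed convention)
  assert(rate >= 0.0f && rate < 1.0f);  // mirrors RandomMaskLikeOp::CheckAttr
  assert(scale > 1.0f);                 // mirrors DropoutGradOp::CheckAttr
}
```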
op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - float rate = op_conf.attr("rate"); - CHECK_GE_OR_RETURN(rate, 0); - CHECK_LT_OR_RETURN(rate, 1); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt8; - return Maybe::Ok(); - }); - -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp index e24277c4235..5745a6cdd25 100644 --- a/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp +++ b/oneflow/user/ops/dynamic_loss_scale_schedule_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -27,24 +28,27 @@ bool IsTensorWithType(const user_op::TensorDesc* desc, DataType data_type) { return desc->data_type() == data_type; } -Maybe InferTensorDesc(user_op::InferContext* ctx) { +} // namespace + +/* static */ Maybe DynamicLossScaleScheduleOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("count_not_finite", 0)))); CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("loss_scale", 0)))); CHECK_OR_RETURN(IsScalarTensor(&(ctx->InputTensorDesc("good_step_counter", 0)))); return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { - CHECK_OR_RETURN( - IsTensorWithType(&(ctx->InputTensorDesc("count_not_finite", 0)), DataType::kInt64)); - CHECK_OR_RETURN(IsTensorWithType(&(ctx->InputTensorDesc("loss_scale", 0)), DataType::kFloat)); - CHECK_OR_RETURN( - IsTensorWithType(&(ctx->InputTensorDesc("good_step_counter", 0)), DataType::kInt64)); - return Maybe::Ok(); +/*static*/ Maybe DynamicLossScaleScheduleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe DynamicLossScaleScheduleOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); } -Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) { +/* static */ Maybe DynamicLossScaleScheduleOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { user_op::InputArgModifier* loss_scale = GetInputArgModifierFn("loss_scale", 0); CHECK_OR_RETURN(loss_scale != nullptr); loss_scale->set_is_mutable(true); @@ -54,17 +58,13 @@ Maybe InputArgModifierFn(const user_op::GetInputArgModifier& GetInputArgMo return Maybe::Ok(); } -} // namespace - -REGISTER_USER_OP("dynamic_loss_scale_schedule") - .Input("count_not_finite") - .Input("loss_scale") - .Input("good_step_counter") - .Attr("increment_period", 2000) - .Attr("multiplier", 2.0) - .SetTensorDescInferFn(InferTensorDesc) - .SetInputArgModifyFn(InputArgModifierFn) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe DynamicLossScaleScheduleOp::InferDataType(user_op::InferContext* ctx) { + CHECK_OR_RETURN( + IsTensorWithType(&(ctx->InputTensorDesc("count_not_finite", 0)), DataType::kInt64)); + CHECK_OR_RETURN(IsTensorWithType(&(ctx->InputTensorDesc("loss_scale", 0)), DataType::kFloat)); + CHECK_OR_RETURN( + IsTensorWithType(&(ctx->InputTensorDesc("good_step_counter", 0)), DataType::kInt64)); + return Maybe::Ok(); +} } // namespace oneflow diff --git 
a/oneflow/user/ops/eager_b_to_s_op.cpp b/oneflow/user/ops/eager_b_to_s_op.cpp
index 542b96fa5df..88a0f4a82a0 100644
--- a/oneflow/user/ops/eager_b_to_s_op.cpp
+++ b/oneflow/user/ops/eager_b_to_s_op.cpp
@@ -19,12 +19,12 @@ limitations under the License.
 #include "oneflow/core/common/shape.h"
 #include "oneflow/core/framework/device.h"
 #include "oneflow/user/ops/comm_net_device_infer_util.h"
+#include "oneflow/core/framework/op_generated.h"

 namespace oneflow {

-namespace {
-
-Maybe<void> TensorDescInfer(user_op::InferContext* ctx) {
+// Can only be called in mirrored mode. TODO: move this comment to ODS.
+/* static */ Maybe<void> EagerBToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const Shape& shape = ctx->Attr<Shape>("shape");
   const std::string& out_parallel_conf_txt = ctx->Attr<std::string>("out_parallel_conf");
   const int64_t out_split_axis = ctx->Attr<int64_t>("out_split_axis");
@@ -40,27 +40,25 @@ Maybe<void> TensorDescInfer(user_op::InferContext* ctx) {
   return Maybe<void>::Ok();
 }

-} // namespace
+/*static*/ Maybe<void> EagerBToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> EagerBToSOp::GetSbp(user_op::SbpContext* ctx) {
+  return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!";
+}
+
+/* static */ Maybe<void> EagerBToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
+  return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!";
+}

-// Can only be called in mirrored
-REGISTER_NO_GRAD_USER_OP("eager_b_to_s")
-    .Input("in")
-    .Output("out")
-    .Attr<int64_t>("out_split_axis", -1)
-    .Attr<std::string>("in_parallel_conf")
-    .Attr<std::string>("out_parallel_conf")
-    .Attr<Shape>("shape")
-    .SetTensorDescInferFn(TensorDescInfer)
-    .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> {
-      return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!";
-    })
-    .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!";
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> EagerBToSOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<Symbol<Device>> EagerBToSOp::InferDevice(user_op::DeviceInferContext* ctx) {
+  return DeviceInferFn<&SyncLaunched>(ctx);
+}

 } // namespace oneflow
diff --git a/oneflow/user/ops/eager_nccl_ops.cpp b/oneflow/user/ops/eager_nccl_ops.cpp
index efbbedd538d..daa5bd045d8 100644
--- a/oneflow/user/ops/eager_nccl_ops.cpp
+++ b/oneflow/user/ops/eager_nccl_ops.cpp
@@ -18,210 +18,233 @@ limitations under the License.
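The hunk below applies the pattern that repeats through the rest of this patch: every fluent REGISTER_USER_OP / REGISTER_NO_GRAD_USER_OP chain becomes a set of static methods on a per-op class declared in the newly included op_generated.h. A minimal compilable sketch of the before/after shape; InferContext, SbpContext, and Maybe here are toy stand-ins, not the real oneflow types:

#include <cstdio>

// Toy stand-ins for the real user_op contexts and oneflow::Maybe; only the
// interface shape matters here.
struct InferContext {};
struct SbpContext {};

template <typename T>
struct Maybe {
  static Maybe Ok() { return Maybe{}; }
};

// Roughly what op_generated.h declares per op (schematic; the real header
// declares more hooks, e.g. InferNdSbp and InferDevice where relevant):
struct EagerNcclAllReduceOp {
  static Maybe<void> InferLogicalTensorDesc(InferContext* ctx);
  static Maybe<void> InferPhysicalTensorDesc(InferContext* ctx);
  static Maybe<void> GetSbp(SbpContext* ctx);
  static Maybe<void> InferDataType(InferContext* ctx);
};

// The hand-written .cpp now defines statics instead of registering lambdas
// through SetTensorDescInferFn / SetGetSbpFn / SetDataTypeInferFn:
Maybe<void> EagerNcclAllReduceOp::InferLogicalTensorDesc(InferContext*) {
  // real body: *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  return Maybe<void>::Ok();
}
Maybe<void> EagerNcclAllReduceOp::InferPhysicalTensorDesc(InferContext* ctx) {
  return InferLogicalTensorDesc(ctx);  // physical shape == logical shape here
}
Maybe<void> EagerNcclAllReduceOp::GetSbp(SbpContext*) { return Maybe<void>::Ok(); }
Maybe<void> EagerNcclAllReduceOp::InferDataType(InferContext*) { return Maybe<void>::Ok(); }

int main() {
  InferContext ctx;
  (void)EagerNcclAllReduceOp::InferPhysicalTensorDesc(&ctx);
  std::puts("generated-interface skeleton wired");
  return 0;
}

One practical gain: the op interface is a single generated declaration rather than a bundle of optional setter lambdas, so the compiler checks that every declared hook is actually defined.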
#include "oneflow/core/common/decorator.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("eager_nccl_all_reduce") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("async_launch", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&IsAsyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("eager_nccl_broadcast") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("root", 0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_reduce") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("root", 0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - UNIMPLEMENTED_THEN_RETURN() << "consistent tensor are not supported"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_reduce_scatter") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .Attr("op_type", "sum") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& shape = ctx->InputShape("in", 0); - DimVector dim_vec; - if (shape.NumAxes() > 0) { - dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); - } - const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); - const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); - if (parallel_num > 1) { - const int64_t& split_axis = out_sbp_para.split_parallel().axis(); - CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); - BalancedSplitter bs(shape.At(split_axis), parallel_num); - dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); - } - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - 
.SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel() || sbp_hint.has_broadcast_parallel()); - } - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); - - // P2S or B2S - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - in_nd_sbp->CopyFrom(in_dis_hint); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_all_gather") - .Input("in") - .Output("out") - .Attr("parallel_conf") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); - } - - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); - - // S(0)->B - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - out_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("eager_nccl_s2s") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("out_split_axis", -1) - .Attr("parallel_conf") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - 
cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } - - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); - - // S(in)->S(out) - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe EagerNcclAllReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclAllReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclAllReduceOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclAllReduceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclAllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&IsAsyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclBroadcastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclBroadcastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclBroadcastOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().PartialSum(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); + ctx->NewBuilder().Broadcast(user_op::OpArg("in", 0)).Broadcast(user_op::OpArg("out", 0)).Build(); + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Broadcast(user_op::OpArg("out", 0)).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclBroadcastOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclBroadcastOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclReduceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclReduceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclReduceOp::GetSbp(user_op::SbpContext* ctx) { + UNIMPLEMENTED_THEN_RETURN() << "consistent tensor are not supported"; +} + +/* static */ Maybe EagerNcclReduceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static 
*/ Maybe EagerNcclReduceScatterOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclReduceScatterOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& shape = ctx->InputShape("in", 0); + DimVector dim_vec; + if (shape.NumAxes() > 0) { + dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend()); + } + const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); + const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); + if (parallel_num > 1) { + const int64_t& split_axis = out_sbp_para.split_parallel().axis(); + CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); + BalancedSplitter bs(shape.At(split_axis), parallel_num); + dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); + } + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclReduceScatterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe EagerNcclReduceScatterOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel() || sbp_hint.has_broadcast_parallel()); + } + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + // P2S or B2S + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + in_nd_sbp->CopyFrom(in_dis_hint); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclReduceScatterOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclReduceScatterOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclAllGatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerNcclAllGatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNcclAllGatherOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe EagerNcclAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); + } + + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + // S(0)->B 
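+ // e.g. with parallel_hierarchy = (2, 4), the loop below yields
+ // in_nd_sbp = (S(0), S(0)) and out_nd_sbp = (B, B): each of the 8 ranks
+ // holds one contiguous row slice of the logical tensor on input and,
+ // after the all-gather, the full logical tensor on output.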
+ const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + out_nd_sbp->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclAllGatherOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclAllGatherOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe EagerNcclS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclS2sOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe EagerNcclS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + // S(in)->S(out) + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerNcclS2sOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNcclS2sOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + } // namespace oneflow diff --git a/oneflow/user/ops/eager_p_to_b_op.cpp b/oneflow/user/ops/eager_p_to_b_op.cpp index 6b30ac9cec8..e1809c23a8d 100644 --- a/oneflow/user/ops/eager_p_to_b_op.cpp +++ b/oneflow/user/ops/eager_p_to_b_op.cpp @@ -19,30 +19,34 @@ limitations under the License. 
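The eager_p_to_b hunk below is the same mechanical rewrite. Semantically the op converts a partial-sum placement into a broadcast one, which is an all-reduce. A toy numeric model of that contract (illustrative only; none of this is oneflow API):

#include <cstdio>
#include <vector>

// P -> B, numerically: on input every rank holds a same-shape addend
// (partial-sum); on output every rank holds the elementwise total
// (broadcast). Functionally an all-reduce.
int main() {
  std::vector<std::vector<int>> partials = {{1, 0, 2}, {0, 3, 0}, {4, 0, 0}};
  std::vector<int> logical(3, 0);
  for (const auto& p : partials)
    for (size_t i = 0; i < logical.size(); ++i) logical[i] += p[i];
  // After eager_p_to_b, every rank holds `logical` = {5, 3, 2}.
  std::printf("logical = {%d, %d, %d} on all %zu ranks\n", logical[0],
              logical[1], logical[2], partials.size());
  return 0;
}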
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { - // Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_p_to_b") - .Input("in") - .Output("out") - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerPToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerPToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerPToBOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerPToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerPToBOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerPToBOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_p_to_s_op.cpp b/oneflow/user/ops/eager_p_to_s_op.cpp index b3cce498e31..0e981e21fa0 100644 --- a/oneflow/user/ops/eager_p_to_s_op.cpp +++ b/oneflow/user/ops/eager_p_to_s_op.cpp @@ -19,12 +19,11 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe TensorDescInfer(user_op::InferContext* ctx) { +/* static */ Maybe EagerPToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -40,27 +39,25 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace +/*static*/ Maybe EagerPToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerPToSOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerPToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; +} -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_p_to_s") - .Input("in") - .Output("out") - .Attr("out_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn(TensorDescInfer) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_b_to_s op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerPToSOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerPToSOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_s_to_b_op.cpp b/oneflow/user/ops/eager_s_to_b_op.cpp index 6407f4b1ebb..3af6f00b4ad 100644 --- a/oneflow/user/ops/eager_s_to_b_op.cpp +++ b/oneflow/user/ops/eager_s_to_b_op.cpp @@ -19,31 +19,34 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_s_to_b") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerSToBOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/*static*/ Maybe EagerSToBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerSToBOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerSToBOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_s_to_b op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerSToBOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerSToBOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_s_to_s_op.cpp b/oneflow/user/ops/eager_s_to_s_op.cpp index 9a72361c135..773b671fcd9 100644 --- a/oneflow/user/ops/eager_s_to_s_op.cpp +++ b/oneflow/user/ops/eager_s_to_s_op.cpp @@ -19,12 +19,11 @@ limitations under the License. 
#include "oneflow/core/common/shape.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe TensorDescInfer(user_op::InferContext* ctx) { +/* static */ Maybe EagerNaiveSToSOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& shape = ctx->Attr("shape"); const std::string& out_parallel_conf_txt = ctx->Attr("out_parallel_conf"); const int64_t out_split_axis = ctx->Attr("out_split_axis"); @@ -40,28 +39,25 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -} // namespace +/*static*/ Maybe EagerNaiveSToSOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EagerNaiveSToSOp::GetSbp(user_op::SbpContext* ctx) { + return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; +} + +/* static */ Maybe EagerNaiveSToSOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; +} -// Can only be called in mirrored -REGISTER_NO_GRAD_USER_OP("eager_naive_s_to_s") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("out_split_axis", -1) - .Attr("in_parallel_conf") - .Attr("out_parallel_conf") - .Attr("shape") - .SetTensorDescInferFn(TensorDescInfer) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - return Error::TypeError() << "eager_naive_s_to_s op doesn't support consistent tensor!"; - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EagerNaiveSToSOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerNaiveSToSOp::InferDevice(user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp index 847ee7272a9..c108a33b8cb 100644 --- a/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp +++ b/oneflow/user/ops/eager_symmetric_s_to_p_op.cpp @@ -17,54 +17,61 @@ limitations under the License. 
#include "oneflow/core/job/parallel_desc.h" #include "oneflow/core/framework/device.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("eager_symmetric_s_to_p") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("parallel_conf") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } +/* static */ Maybe EagerSymmetricSToPOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} - in_nd_sbp->clear_sbp_parallel(); - out_nd_sbp->clear_sbp_parallel(); +/*static*/ Maybe EagerSymmetricSToPOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); - } - return Maybe::Ok(); - }); +/* static */ Maybe EagerSymmetricSToPOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerSymmetricSToPOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_nd_sbp = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + in_nd_sbp->clear_sbp_parallel(); + out_nd_sbp->clear_sbp_parallel(); + + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for 
(int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_nd_sbp->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_nd_sbp->add_sbp_parallel()->mutable_partial_sum_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EagerSymmetricSToPOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> EagerSymmetricSToPOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp index 4bce761d226..7a143bb4ecd 100644 --- a/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp +++ b/oneflow/user/ops/elementwise_maximum_minimum_ops.cpp @@ -14,13 +14,14 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { using namespace user_op; -Maybe GetSbpSignature(SbpContext* ctx) { +Maybe GetSbpSignature_(SbpContext* ctx) { const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); const Shape& y_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0).shape(); @@ -35,7 +36,7 @@ Maybe GetSbpSignature(SbpContext* ctx) { return Maybe::Ok(); } -Maybe InferTensorDesc(InferContext* ctx) { +Maybe InferTensorDesc_(InferContext* ctx) { const TensorDesc& tensor_x = ctx->InputTensorDesc("x", 0); const TensorDesc& tensor_y = ctx->InputTensorDesc("y", 0); @@ -56,7 +57,7 @@ Maybe InferTensorDesc(InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(InferContext* ctx) { +Maybe InferDataType_(InferContext* ctx) { const TensorDesc& tensor_dz = ctx->InputTensorDesc("dz", 0); TensorDesc* tensor_dx = ctx->OutputTensorDesc("dx", 0); TensorDesc* tensor_dy = ctx->OutputTensorDesc("dy", 0); @@ -101,36 +102,55 @@ user_op::BackwardOpConfGenFn MakeGenBackwardOpFn(const std::string& op_type_name } // namespace -#define REGISTER_ELEMENTWISE_XIMUM_FW_OP(op_type_name) \ - REGISTER_USER_OP(op_type_name) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType) - -#define REGISTER_ELEMENTWISE_XIMUM_BW_OP(op_type_name) \ - REGISTER_USER_OP(op_type_name) \ - .Input("dz") \ - .Input("x") \ - .Input("y") \ - .OptionalOutput("dx") \ - .OptionalOutput("dy") \ - .SetTensorDescInferFn(InferTensorDesc) \ - .SetGetSbpFn(GetSbpSignature) \ - .SetDataTypeInferFn(InferDataType) +#define DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix) \ + /* static */ Maybe op_class_name_prefix##Op::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##Op::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::GetSbp(user_op::SbpContext* ctx) { \ + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##Op::InferDataType(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ + } + +#define DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix) \ + /* 
static */ Maybe op_class_name_prefix##BackwardOp::InferLogicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferTensorDesc_(ctx); \ + } \ + \ + /*static*/ Maybe op_class_name_prefix##BackwardOp::InferPhysicalTensorDesc( \ + user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##BackwardOp::GetSbp(user_op::SbpContext* ctx) { \ + return GetSbpSignature_(ctx); \ + } \ + \ + /* static */ Maybe op_class_name_prefix##BackwardOp::InferDataType( \ + user_op::InferContext* ctx) { \ + return InferDataType_(ctx); \ + } #define REGISTER_ELEMENTWISE_XIMUM_GRAD(op_type_name) \ REGISTER_USER_OP_GRAD(op_type_name) \ .SetBackwardOpConfGenFn(MakeGenBackwardOpFn(std::string(op_type_name))); -#define REGISTER_ELEMENTWISE_XIMUM_OP(op_type_name) \ - REGISTER_ELEMENTWISE_XIMUM_FW_OP(op_type_name); \ - REGISTER_ELEMENTWISE_XIMUM_BW_OP(op_type_name "_backward"); \ +#define REGISTER_ELEMENTWISE_XIMUM_OP(op_type_name, op_class_name_prefix) \ + DEF_ELEMENTWISE_XIMUM_FW_OP(op_class_name_prefix); \ + DEF_ELEMENTWISE_XIMUM_BW_OP(op_class_name_prefix); \ REGISTER_ELEMENTWISE_XIMUM_GRAD(op_type_name); -REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_maximum"); -REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_minimum"); +REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_maximum", ElementwiseMaximum); +REGISTER_ELEMENTWISE_XIMUM_OP("elementwise_minimum", ElementwiseMinimum); } // namespace oneflow diff --git a/oneflow/user/ops/elu_op.cpp b/oneflow/user/ops/elu_op.cpp index 13cf0de77ec..9de85d34655 100644 --- a/oneflow/user/ops/elu_op.cpp +++ b/oneflow/user/ops/elu_op.cpp @@ -14,63 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe EluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("elu") - .Input("in") - .Attr("alpha") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe EluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("elu_grad") - .Input("x") - .Input("dy") - .Attr("alpha") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - 
.Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe EluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe EluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe EluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe EluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -90,6 +89,4 @@ REGISTER_USER_OP_GRAD("elu").SetBackwardOpConfGenFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/empty_op.cpp b/oneflow/user/ops/empty_op.cpp index ff8e3c45c6d..8c8c24f68d8 100644 --- a/oneflow/user/ops/empty_op.cpp +++ b/oneflow/user/ops/empty_op.cpp @@ -15,45 +15,45 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("empty") - .Output("out") - .SetOutputBufferNum(1) - .Attr("dtype") - .Attr("shape") - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& shape = ctx->Attr("shape"); - DimVector dim_vec{shape.dim_vec()}; - - const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); - if (out_sbp_para.has_split_parallel()) { - const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); - if (parallel_num > 1) { - const int64_t& split_axis = out_sbp_para.split_parallel().axis(); - CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); - BalancedSplitter bs(shape.At(split_axis), parallel_num); - dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); - } - } - - *ctx->OutputShape("out", 0) = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/* static */ Maybe EmptyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = Shape(ctx->Attr("shape").dim_vec()); + return Maybe::Ok(); +} + +/* static */ Maybe EmptyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + const Shape& shape = ctx->Attr("shape"); + DimVector dim_vec{shape.dim_vec()}; + + const cfg::SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0); + if (out_sbp_para.has_split_parallel()) { + const int64_t& parallel_num = ctx->parallel_ctx().parallel_num(); + if (parallel_num > 1) { + const int64_t& split_axis = out_sbp_para.split_parallel().axis(); + CHECK_LT_OR_RETURN(split_axis, dim_vec.size()); + BalancedSplitter bs(shape.At(split_axis), parallel_num); + dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size(); + } + } + + *ctx->OutputShape("out", 0) = Shape(dim_vec); + return Maybe::Ok(); +} + +/* static */ Maybe EmptyOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } + +/* static */ Maybe EmptyOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} + +/* static */ Maybe EmptyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} + } // namespace oneflow diff --git a/oneflow/user/ops/expand_dims_op.cpp b/oneflow/user/ops/expand_dims_op.cpp index 99c1be0f79b..f5031f7a1b3 100644 --- a/oneflow/user/ops/expand_dims_op.cpp +++ b/oneflow/user/ops/expand_dims_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -28,43 +29,45 @@ int32_t TransformNegativeAxisToPositive(int32_t axis, const int32_t num_axes) { } // namespace -REGISTER_USER_OP("expand_dims") - .Input("in") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - const int32_t axis = - TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); +/* static */ Maybe ExpandDimsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + const int32_t axis = + TransformNegativeAxisToPositive(ctx->Attr("axis"), in_shape.NumAxes()); - auto dim_vec = in_shape.dim_vec(); - dim_vec.insert(dim_vec.begin() + axis, 1); - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const int32_t axis = - TransformNegativeAxisToPositive(ctx->Attr("axis"), in_tensor.shape().NumAxes()); + auto dim_vec = in_shape.dim_vec(); + dim_vec.insert(dim_vec.begin() + axis, 1); + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} - auto dim_vec = in_tensor.shape().dim_vec(); - FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), in_axis) - .Split(user_op::OpArg("out", 0), in_axis < axis ? in_axis : in_axis + 1) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ExpandDimsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ExpandDimsOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const int32_t axis = + TransformNegativeAxisToPositive(ctx->Attr("axis"), in_tensor.shape().NumAxes()); + + auto dim_vec = in_tensor.shape().dim_vec(); + FOR_RANGE(int32_t, in_axis, 0, dim_vec.size()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), in_axis) + .Split(user_op::OpArg("out", 0), in_axis < axis ? in_axis : in_axis + 1) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ExpandDimsOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("expand_dims") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/expand_op.cpp b/oneflow/user/ops/expand_op.cpp index 4ac73550bf9..9e8cfd5c2ef 100644 --- a/oneflow/user/ops/expand_op.cpp +++ b/oneflow/user/ops/expand_op.cpp @@ -15,116 +15,119 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/expand_kernel_utils.h" +#include "oneflow/core/framework/op_generated.h" + namespace oneflow { -REGISTER_USER_OP("expand") - .Input("in") - .Output("out") - .Attr>("logical_in_shape") - .Attr>("logical_expand_shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& input_shape = ctx->InputShape("in", 0); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - - std::vector in_shape; - in_shape.resize(input_shape.NumAxes()); - for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } - - std::vector out_shape; - std::vector stride; - CHECK_JUST(getOutShapeAndStrideForFp(in_shape, logical_expand_shape, out_shape, stride)); - - Shape* output_shape = ctx->OutputShape("out", 0); - DimVector dim_vec(out_shape.begin(), out_shape.end()); - *output_shape = Shape(dim_vec); - - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const std::vector& logical_in_shape = - ctx->Attr>("logical_in_shape"); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - std::vector logical_out_shape; - std::vector stride; - CHECK_JUST( - getOutShapeAndStride(logical_in_shape, logical_expand_shape, logical_out_shape, stride)); - - int offset = logical_out_shape.size() - logical_in_shape.size(); - FOR_RANGE(int64_t, i, 0, logical_in_shape.size()) { - if (logical_in_shape[i] == logical_out_shape[i + offset]) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i + offset) - .Build(); - } - } +/* static */ Maybe ExpandOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& input_shape = ctx->InputShape("in", 0); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + + std::vector in_shape; + in_shape.resize(input_shape.NumAxes()); + for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } + + std::vector out_shape; + std::vector stride; + CHECK_JUST(getOutShapeAndStrideForFp(in_shape, logical_expand_shape, out_shape, stride)); + + Shape* output_shape = ctx->OutputShape("out", 0); + DimVector dim_vec(out_shape.begin(), out_shape.end()); + *output_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe ExpandOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ExpandOp::GetSbp(user_op::SbpContext* ctx) { + const std::vector& logical_in_shape = + ctx->Attr>("logical_in_shape"); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + std::vector logical_out_shape; + std::vector stride; + CHECK_JUST( + getOutShapeAndStride(logical_in_shape, logical_expand_shape, logical_out_shape, stride)); + + int offset = logical_out_shape.size() - logical_in_shape.size(); + FOR_RANGE(int64_t, i, 0, logical_in_shape.size()) { + if (logical_in_shape[i] == logical_out_shape[i + offset]) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i + offset) .Build(); - return Maybe::Ok(); - }); - -REGISTER_USER_OP("expand_grad") - .Input("in") - .Output("out") - .Attr>("logical_out_shape") - .Attr>("logical_expand_shape") - 
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& input_shape = ctx->InputShape("in", 0); - const std::vector& logical_out_shape = - ctx->Attr>("logical_out_shape"); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - - std::vector in_shape; - in_shape.resize(input_shape.NumAxes()); - for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } - std::vector out_shape; - std::vector stride; - CHECK_JUST(getOutShapeAndStrideForBp(logical_out_shape, logical_expand_shape, in_shape, - out_shape, stride)); - - Shape* output_shape = ctx->OutputShape("out", 0); - DimVector dim_vec(out_shape.begin(), out_shape.end()); - *output_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& input_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const std::vector& logical_out_shape = - ctx->Attr>("logical_out_shape"); - const std::vector& logical_expand_shape = - ctx->Attr>("logical_expand_shape"); - - int offset = input_tensor.shape().NumAxes() - logical_out_shape.size(); - FOR_RANGE(int64_t, i, 0, logical_out_shape.size()) { - if (logical_out_shape[i] == input_tensor.shape().At(i + offset)) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i + offset) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - } + } + } + + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/* static */ Maybe ExpandOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ExpandGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& input_shape = ctx->InputShape("in", 0); + const std::vector& logical_out_shape = + ctx->Attr>("logical_out_shape"); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + + std::vector in_shape; + in_shape.resize(input_shape.NumAxes()); + for (int i = 0; i < input_shape.NumAxes(); ++i) { in_shape[i] = input_shape.At(i); } + std::vector out_shape; + std::vector stride; + CHECK_JUST(getOutShapeAndStrideForBp(logical_out_shape, logical_expand_shape, in_shape, out_shape, + stride)); + + Shape* output_shape = ctx->OutputShape("out", 0); + DimVector dim_vec(out_shape.begin(), out_shape.end()); + *output_shape = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe ExpandGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ExpandGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const std::vector& logical_out_shape = + ctx->Attr>("logical_out_shape"); + const std::vector& logical_expand_shape = + ctx->Attr>("logical_expand_shape"); + + int offset = input_tensor.shape().NumAxes() - logical_out_shape.size(); + FOR_RANGE(int64_t, i, 0, logical_out_shape.size()) { + if (logical_out_shape[i] == input_tensor.shape().At(i + offset)) { ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) + .Split(user_op::OpArg("in", 0), i + offset) + .Split(user_op::OpArg("out", 0), i) .Build(); - return Maybe::Ok(); - }); + } + } + + ctx->NewBuilder() + 
.PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ExpandGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("expand").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/eye_op.cpp b/oneflow/user/ops/eye_op.cpp index 80e0465b8a4..077758b2452 100644 --- a/oneflow/user/ops/eye_op.cpp +++ b/oneflow/user/ops/eye_op.cpp @@ -14,27 +14,29 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("eye") - .Output("out") - .Attr("rows") - .Attr("cols") - .Attr("dtype") - .Attr>("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - int64_t rows = ctx->Attr("rows"); - int64_t cols = ctx->Attr("cols"); - *ctx->OutputShape("out", 0) = Shape({rows, cols}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); - return Maybe::Ok(); - }); + +/* static */ Maybe EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + int64_t rows = ctx->Attr("rows"); + int64_t cols = ctx->Attr("cols"); + *ctx->OutputShape("out", 0) = Shape({rows, cols}); + return Maybe::Ok(); +} + +/*static*/ Maybe EyeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe EyeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe EyeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("dtype"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/fake_quantization_op.cpp b/oneflow/user/ops/fake_quantization_op.cpp index d8fa0242f96..fbe6a7d8ca6 100644 --- a/oneflow/user/ops/fake_quantization_op.cpp +++ b/oneflow/user/ops/fake_quantization_op.cpp @@ -14,104 +14,99 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_USER_OP("fake_quantization") - .Input("in") - .Input("scale") - .Input("zero_point") - .Output("out") - // NOTE(Liang Depeng): "google" or "cambricon" - .Attr("quantization_formula", "google") - // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned - // integer - .Attr("quantization_bit", 8) - // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer - .Attr("quantization_scheme", "symmetric") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - const Shape& scale_shape = ctx->InputShape("scale", 0); - const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); - - // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for - // convolution weights. 
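The expand diff above preserves the op's broadcasting shape rule; it is easiest to see with concrete numbers. The helper below is a hypothetical stand-in for the shape half of getOutShapeAndStrideForFp from expand_kernel_utils (it is a sketch under assumed numpy-style broadcasting semantics, not the library's implementation):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical mirror of the expand shape rule: the input aligns to the
// right of the expand shape, size-1 axes stretch, and -1 keeps the input
// extent on that axis.
std::vector<int32_t> computeExpandShape(const std::vector<int32_t>& in,
                                        const std::vector<int32_t>& expand) {
  if (expand.size() < in.size()) { throw std::invalid_argument("expand rank too small"); }
  const size_t offset = expand.size() - in.size();
  std::vector<int32_t> out(expand.size());
  for (size_t i = 0; i < expand.size(); ++i) {
    const int32_t in_dim = i < offset ? 1 : in[i - offset];
    if (expand[i] == -1) {
      out[i] = in_dim;  // keep the input extent on this axis
    } else if (in_dim == 1 || in_dim == expand[i]) {
      out[i] = expand[i];  // broadcast axes are read with stride 0 by the kernel
    } else {
      throw std::invalid_argument("expand shape incompatible with input");
    }
  }
  return out;
}

int main() {
  // (3, 1) expanded by (2, 3, 4) -> prints "2 3 4".
  for (int32_t d : computeExpandShape({3, 1}, {2, 3, 4})) { std::cout << d << ' '; }
  std::cout << '\n';
}

The stride vector the real helper also fills is what lets the kernel realize broadcast axes with stride 0 instead of materializing copies.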
diff --git a/oneflow/user/ops/eye_op.cpp b/oneflow/user/ops/eye_op.cpp
index 80e0465b8a4..077758b2452 100644
--- a/oneflow/user/ops/eye_op.cpp
+++ b/oneflow/user/ops/eye_op.cpp
@@ -14,27 +14,29 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
-REGISTER_NO_GRAD_USER_OP("eye")
-    .Output("out")
-    .Attr<int64_t>("rows")
-    .Attr<int64_t>("cols")
-    .Attr<DataType>("dtype")
-    .Attr<std::vector<std::string>>("nd_sbp")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      int64_t rows = ctx->Attr<int64_t>("rows");
-      int64_t cols = ctx->Attr<int64_t>("cols");
-      *ctx->OutputShape("out", 0) = Shape({rows, cols});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->Attr<DataType>("dtype");
-      return Maybe<void>::Ok();
-    });
+
+/* static */ Maybe<void> EyeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  int64_t rows = ctx->Attr<int64_t>("rows");
+  int64_t cols = ctx->Attr<int64_t>("cols");
+  *ctx->OutputShape("out", 0) = Shape({rows, cols});
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> EyeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> EyeOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> EyeOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->Attr<DataType>("dtype");
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/fake_quantization_op.cpp b/oneflow/user/ops/fake_quantization_op.cpp
index d8fa0242f96..fbe6a7d8ca6 100644
--- a/oneflow/user/ops/fake_quantization_op.cpp
+++ b/oneflow/user/ops/fake_quantization_op.cpp
@@ -14,104 +14,99 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
-namespace {
-REGISTER_USER_OP("fake_quantization")
-    .Input("in")
-    .Input("scale")
-    .Input("zero_point")
-    .Output("out")
-    // NOTE(Liang Depeng): "google" or "cambricon"
-    .Attr<std::string>("quantization_formula", "google")
-    // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned
-    // integer
-    .Attr<int32_t>("quantization_bit", 8)
-    // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer
-    .Attr<std::string>("quantization_scheme", "symmetric")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& in_shape = ctx->InputShape("in", 0);
-      const Shape& scale_shape = ctx->InputShape("scale", 0);
-      const Shape& zero_point_shape = ctx->InputShape("zero_point", 0);
-
-      // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for
-      // convolution weights.
-      if (scale_shape.elem_cnt() > 1) {
-        CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0));
-        CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0));
-      }
+/* static */ Maybe<void> FakeQuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& in_shape = ctx->InputShape("in", 0);
+  const Shape& scale_shape = ctx->InputShape("scale", 0);
+  const Shape& zero_point_shape = ctx->InputShape("zero_point", 0);
 
-      *ctx->OutputShape("out", 0) = in_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0);
-      CHECK_OR_RETURN(scale != nullptr);
-      scale->set_requires_grad(false);
-
-      user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0);
-      CHECK_OR_RETURN(zero_point != nullptr);
-      zero_point->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      const Shape& logical_scale_shape =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape();
-      ctx->NewBuilder()
-          .Broadcast(user_op::OpArg("in", 0))
-          .Broadcast(user_op::OpArg("scale", 0))
-          .Broadcast(user_op::OpArg("zero_point", 0))
-          .Broadcast(user_op::OpArg("out", 0))
-          .Build();
-      if (logical_scale_shape.elem_cnt() > 1) {
-        // NOTE(Liang Depeng): only consider convolution weight per-channel quantization
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), 0)
-            .Split(user_op::OpArg("scale", 0), 0)
-            .Split(user_op::OpArg("zero_point", 0), 0)
-            .Split(user_op::OpArg("out", 0), 0)
-            .Build();
-      } else {
-        // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise
-        // ops
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), 0)
-            .Broadcast(user_op::OpArg("scale", 0))
-            .Broadcast(user_op::OpArg("zero_point", 0))
-            .Split(user_op::OpArg("out", 0), 0)
-            .Build();
-      }
-      FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), i)
-            .Broadcast(user_op::OpArg("scale", 0))
-            .Broadcast(user_op::OpArg("zero_point", 0))
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      const int32_t quantization_bit = op_conf.attr<int32_t>("quantization_bit");
-      CHECK_GT_OR_RETURN(quantization_bit, 1);
-      CHECK_LE_OR_RETURN(quantization_bit, 8);
-
-      std::string quantization_scheme = op_conf.attr<std::string>("quantization_scheme");
-      CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
-
-      std::string quantization_formula = op_conf.attr<std::string>("quantization_formula");
-      CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+  // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for
+  // convolution weights.
+  if (scale_shape.elem_cnt() > 1) {
+    CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0));
+    CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0));
+  }
+
+  *ctx->OutputShape("out", 0) = in_shape;
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> FakeQuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> FakeQuantizationOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  const Shape& logical_scale_shape =
+      ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape();
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("in", 0))
+      .Broadcast(user_op::OpArg("scale", 0))
+      .Broadcast(user_op::OpArg("zero_point", 0))
+      .Broadcast(user_op::OpArg("out", 0))
+      .Build();
+  if (logical_scale_shape.elem_cnt() > 1) {
+    // NOTE(Liang Depeng): only consider convolution weight per-channel quantization
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), 0)
+        .Split(user_op::OpArg("scale", 0), 0)
+        .Split(user_op::OpArg("zero_point", 0), 0)
+        .Split(user_op::OpArg("out", 0), 0)
+        .Build();
+  } else {
+    // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise
+    // ops
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), 0)
+        .Broadcast(user_op::OpArg("scale", 0))
+        .Broadcast(user_op::OpArg("zero_point", 0))
+        .Split(user_op::OpArg("out", 0), 0)
+        .Build();
+  }
+  FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), i)
+        .Broadcast(user_op::OpArg("scale", 0))
+        .Broadcast(user_op::OpArg("zero_point", 0))
+        .Split(user_op::OpArg("out", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> FakeQuantizationOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0);
+  CHECK_OR_RETURN(scale != nullptr);
+  scale->set_requires_grad(false);
+
+  user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0);
+  CHECK_OR_RETURN(zero_point != nullptr);
+  zero_point->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> FakeQuantizationOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                                       const user_op::UserOpConfWrapper& conf) {
+  const int32_t quantization_bit = conf.attr<int32_t>("quantization_bit");
+  CHECK_GT_OR_RETURN(quantization_bit, 1);
+  CHECK_LE_OR_RETURN(quantization_bit, 8);
+
+  std::string quantization_scheme = conf.attr<std::string>("quantization_scheme");
+  CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
+
+  std::string quantization_formula = conf.attr<std::string>("quantization_formula");
+  CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> FakeQuantizationOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("fake_quantization")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
@@ -129,6 +124,4 @@ REGISTER_USER_OP_GRAD("fake_quantization")
       return Maybe<void>::Ok();
     });
 
-}  // namespace
-
 }  // namespace oneflow
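CheckAttr above restricts quantization_bit to (1, 8] and the scheme/formula strings; the arithmetic those attributes configure is standard fake quantization. A minimal sketch of the symmetric, per-layer, "google"-formula case (assumed here from the attribute names and defaults; the real kernels live under oneflow/user/kernels):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Assumed reference semantics for symmetric fake quantization with
// quantization_bit = 8: round to the int8 grid, clamp, dequantize.
float fakeQuantSymmetric(float x, float scale, int quantization_bit) {
  const float q_max = static_cast<float>((1 << (quantization_bit - 1)) - 1);  // 127
  const float q_min = -q_max - 1.0f;                                          // -128
  float q = std::nearbyint(x / scale);      // quantize
  q = std::min(q_max, std::max(q_min, q));  // clamp to the representable range
  return q * scale;                         // dequantize: the "fake" in fake quantization
}

int main() {
  // With scale = 0.1, 0.25 maps to the grid point 2 and comes back as 0.2.
  std::printf("%f\n", fakeQuantSymmetric(0.25f, 0.1f, 8));
}

Per-channel quantization (scale_shape.elem_cnt() > 1 above) applies the same formula with one scale/zero_point per output channel, which is why both shapes are checked against in_shape.At(0).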
diff --git a/oneflow/user/ops/flatten_op.cpp b/oneflow/user/ops/flatten_op.cpp
index 03f0b9b2b97..487d7abc372 100644
--- a/oneflow/user/ops/flatten_op.cpp
+++ b/oneflow/user/ops/flatten_op.cpp
@@ -14,39 +14,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
-namespace {
-
-Maybe<void> GetSbpFn(user_op::SbpContext* ctx) {
-  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape();
-  const int32_t start_dim = ctx->Attr<int32_t>("start_dim");
-  const int32_t end_dim = ctx->Attr<int32_t>("end_dim");
-
-  CHECK_GE_OR_RETURN(start_dim, 0);
-  CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes());
-  const int32_t true_end_dim = end_dim < 0 ? end_dim + in_shape.NumAxes() : end_dim;
-  CHECK_GE_OR_RETURN(true_end_dim, 0);
-  CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes());
-  CHECK_LE_OR_RETURN(start_dim, true_end_dim);
-
-  for (int i = 0; i <= start_dim; ++i) {
-    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
-  }
-  const int32_t diff = true_end_dim - start_dim;
-  for (int i = true_end_dim + 1; i < in_shape.NumAxes(); ++i) {
-    ctx->NewBuilder()
-        .Split(user_op::OpArg("in", 0), i)
-        .Split(user_op::OpArg("out", 0), i - diff)
-        .Build();
-  }
-
-  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> TensorDescInferFn(user_op::InferContext* ctx) {
+/* static */ Maybe<void> FlattenOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const int32_t start_dim = ctx->Attr<int32_t>("start_dim");
   const int32_t end_dim = ctx->Attr<int32_t>("end_dim");
   const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0);
@@ -79,19 +51,41 @@ Maybe<void> TensorDescInferFn(user_op::InferContext* ctx) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> DataTypeInferFn(user_op::InferContext* ctx) {
-  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+/*static*/ Maybe<void> FlattenOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> FlattenOp::GetSbp(user_op::SbpContext* ctx) {
+  const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape();
+  const int32_t start_dim = ctx->Attr<int32_t>("start_dim");
+  const int32_t end_dim = ctx->Attr<int32_t>("end_dim");
+
+  CHECK_GE_OR_RETURN(start_dim, 0);
+  CHECK_LT_OR_RETURN(start_dim, in_shape.NumAxes());
+  const int32_t true_end_dim = end_dim < 0 ? end_dim + in_shape.NumAxes() : end_dim;
+  CHECK_GE_OR_RETURN(true_end_dim, 0);
+  CHECK_LT_OR_RETURN(true_end_dim, in_shape.NumAxes());
+  CHECK_LE_OR_RETURN(start_dim, true_end_dim);
+
+  for (int i = 0; i <= start_dim; ++i) {
+    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
+  }
+  const int32_t diff = true_end_dim - start_dim;
+  for (int i = true_end_dim + 1; i < in_shape.NumAxes(); ++i) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), i)
+        .Split(user_op::OpArg("out", 0), i - diff)
+        .Build();
+  }
+
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
   return Maybe<void>::Ok();
 }
 
-REGISTER_USER_OP("flatten")
-    .Input("in")
-    .Output("out")
-    .Attr<int32_t>("start_dim", 0)
-    .Attr<int32_t>("end_dim", -1)
-    .SetTensorDescInferFn(TensorDescInferFn)
-    .SetGetSbpFn(GetSbpFn)
-    .SetDataTypeInferFn(DataTypeInferFn);
+/* static */ Maybe<void> FlattenOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("flatten").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                                            user_op::AddOpFn AddOp) -> Maybe<void> {
@@ -109,5 +103,4 @@ REGISTER_USER_OP_GRAD("flatten").SetGenBackwardOpConfFn([](const user_op::UserOp
   return Maybe<void>::Ok();
 });
 
-}  // namespace
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/flip_op.cpp b/oneflow/user/ops/flip_op.cpp
index bb7ce9fbd8c..e062fc6b422 100644
--- a/oneflow/user/ops/flip_op.cpp
+++ b/oneflow/user/ops/flip_op.cpp
@@ -14,50 +14,48 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("flip")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int32_t>>("dims")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
-      const int input_dims = x_desc.shape().NumAxes();
-      const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
-      CHECK_OR_RETURN(dims.size() <= input_dims)
-          << "len of dims must less than len of input tensor";
-      for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; }
-      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
-      *y_desc->mut_shape() = x_desc.shape();
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ auto FlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {
+  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
+  const int input_dims = x_desc.shape().NumAxes();
+  const std::vector<int32_t> dims = ctx->Attr<std::vector<int32_t>>("dims");
+  CHECK_OR_RETURN(dims.size() <= input_dims) << "len of dims must less than len of input tensor";
+  for (auto x : dims) { CHECK_OR_RETURN(x < input_dims) << "dims parameter is illegal."; }
+  user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+  *y_desc->mut_shape() = x_desc.shape();
+  return Maybe<void>::Ok();
+}
+/*static*/ auto FlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {
+  return FlipOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ auto FlipOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {
+  ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ auto FlipOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("flip_grad")
-    .Input("dy")
-    .Output("dx")
-    .Attr<std::vector<int32_t>>("dims")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      Shape* dx_shape = ctx->OutputShape("dx", 0);
-      *dx_shape = dy_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ auto FlipGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  Shape* dx_shape = ctx->OutputShape("dx", 0);
+  *dx_shape = dy_shape;
+  return Maybe<void>::Ok();
+}
+/*static*/ auto FlipGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe<void> {
+  return FlipGradOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ auto FlipGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe<void> {
+  ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ auto FlipGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe<void> {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/fused_bias_add_op.cpp b/oneflow/user/ops/fused_bias_add_op.cpp
index 1024b4acfdc..46f9394ff18 100644
--- a/oneflow/user/ops/fused_bias_add_op.cpp
+++ b/oneflow/user/ops/fused_bias_add_op.cpp
@@ -14,93 +14,93 @@ See the License for the specific language governing permissions and
 limitations under the License.
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("fused_bias_add_gelu") - .Input("a") - .Input("b") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .Split(user_op::OpArg("b", 0), 0) - .Split(user_op::OpArg("a", 0), axis) - .Split(ctx->outputs(), axis) - .Build(); - return Maybe::Ok(); - }); +/*static*/ auto FusedBiasAddGeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); + *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedBiasAddGeluOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedBiasAddGeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const auto axis = ctx->Attr("axis"); + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .Split(user_op::OpArg("b", 0), 0) + .Split(user_op::OpArg("a", 0), axis) + .Split(ctx->outputs(), axis) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + 
CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); + *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_bias_add_gelu_grad") - .Input("a") - .Input("b") - .Input("dy") - .Output("dx") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - *ctx->OutputShape("dx", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("dx", 0) = a_tensor_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("dx", 0) = a_tensor_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .Split(user_op::OpArg("b", 0), 0) - .Split(user_op::OpArg("a", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(ctx->outputs(), axis) - .Build(); - return Maybe::Ok(); - }); +/*static*/ auto FusedBiasAddGeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedBiasAddGeluGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedBiasAddGeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + *ctx->OutputDType("dx", 0) = a_tensor_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddGeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const auto axis = ctx->Attr("axis"); + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(ctx->outputs(), i) + .Build(); + } + ctx->NewBuilder() + .Split(user_op::OpArg("b", 0), 0) + .Split(user_op::OpArg("a", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(ctx->outputs(), axis) + .Build(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_bias_add_gelu") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -141,57 +141,55 @@ REGISTER_USER_OP_GRAD("fused_bias_add_gelu") return Maybe::Ok(); }); -REGISTER_USER_OP("fused_bias_add_mask_scale") - .Input("a") - .Input("b") - .Input("mask") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("axis") - .Attr("scale") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - const auto& 
mask_tensor_desc = ctx->InputTensorDesc("mask", 0); - const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); - const auto bias_add_axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(bias_add_axis, 0); - CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); - CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); - *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); - *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); - *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto axis = ctx->Attr("axis"); - std::vector split_args; - split_args.emplace_back("a", 0); - split_args.emplace_back("mask", 0); - split_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - split_args.emplace_back("_add_to_output", 0); - } - for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); - ++i) { - if (i == axis) { continue; } - ctx->NewBuilder().Split(split_args, i).Broadcast(user_op::OpArg("b", 0)).Build(); - } - ctx->NewBuilder().Split(user_op::OpArg("b", 0), 0).Split(split_args, axis).Build(); - return Maybe::Ok(); - }); +/*static*/ auto FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + const auto& mask_tensor_desc = ctx->InputTensorDesc("mask", 0); + const auto& b_tensor_desc = ctx->InputTensorDesc("b", 0); + const auto bias_add_axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(b_tensor_desc.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(bias_add_axis, 0); + CHECK_LT_OR_RETURN(bias_add_axis, a_tensor_desc.shape().NumAxes()); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape().At(bias_add_axis), b_tensor_desc.shape().At(0)); + CHECK_EQ_OR_RETURN(a_tensor_desc.shape(), mask_tensor_desc.shape()); + *ctx->OutputShape("out", 0) = a_tensor_desc.shape(); + *ctx->OutputIsDynamic("out", 0) = a_tensor_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddMaskScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedBiasAddMaskScaleOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedBiasAddMaskScaleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const auto& a_tensor_desc = ctx->InputTensorDesc("a", 0); + *ctx->OutputDType("out", 0) = a_tensor_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddMaskScaleOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedBiasAddMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const auto axis = 
ctx->Attr("axis"); + std::vector split_args; + split_args.emplace_back("a", 0); + split_args.emplace_back("mask", 0); + split_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + split_args.emplace_back("_add_to_output", 0); + } + for (int64_t i = 0; i < ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape().NumAxes(); + ++i) { + if (i == axis) { continue; } + ctx->NewBuilder().Split(split_args, i).Broadcast(user_op::OpArg("b", 0)).Build(); + } + ctx->NewBuilder().Split(user_op::OpArg("b", 0), 0).Split(split_args, axis).Build(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_bias_add_mask_scale") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/fused_cast_scale_op.cpp b/oneflow/user/ops/fused_cast_scale_op.cpp index e09f91c9375..816a10efb06 100644 --- a/oneflow/user/ops/fused_cast_scale_op.cpp +++ b/oneflow/user/ops/fused_cast_scale_op.cpp @@ -14,11 +14,11 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { -Maybe TensorDescInfer(user_op::InferContext* ctx) { +Maybe FusedCastScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); CHECK_EQ_OR_RETURN(scale_by_tensor.shape().NumAxes(), 1); @@ -29,14 +29,18 @@ Maybe TensorDescInfer(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe DataTypeInfer(user_op::InferContext* ctx) { +Maybe FusedCastScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return FusedCastScaleOp::InferLogicalTensorDesc(ctx); +} + +Maybe FusedCastScaleOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& scale_by_tensor = ctx->InputTensorDesc("scale_by_tensor", 0); user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); *y->mut_data_type() = scale_by_tensor.data_type(); return Maybe::Ok(); } -Maybe GetSbpSignatures(user_op::SbpContext* ctx) { +Maybe FusedCastScaleOp::GetSbp(user_op::SbpContext* ctx) { const auto& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); for (int i = 0; i < x.shape().NumAxes(); ++i) { ctx->NewBuilder() @@ -58,14 +62,4 @@ Maybe GetSbpSignatures(user_op::SbpContext* ctx) { return Maybe::Ok(); } -REGISTER_USER_OP("fused_cast_scale") - .Input("x") - .Input("scale_by_tensor") - .Output("y") - .Attr("scale", 1.0) - .SetTensorDescInferFn(TensorDescInfer) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn(DataTypeInfer); - -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp index 32bf8db067d..028d218359e 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_dropout_op.cpp @@ -14,106 +14,102 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); + *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedScaleMaskSoftmaxDropoutOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + *ctx->OutputDType("y", 0) = x_desc.data_type(); + *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn("dropout_mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + CHECK_OR_RETURN(dropout_mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + dropout_mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dropout_mask", 0), axis) + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("softmax_y", 0), axis) + .Build(); + } + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_scale_mask_softmax_dropout") - .Input("x") - .Input("mask") - .Input("dropout_mask") - .Output("y") - .Output("softmax_y") - .Attr("scale_value", 1.0) - .Attr("mask_fill_value", 0.) 
- .Attr("dropout_scale_value", 1.0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); - *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - user_op::InputArgModifier* dropout_mask_modifier = GetInputArgModifierFn("dropout_mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - CHECK_OR_RETURN(dropout_mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - dropout_mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dropout_mask", 0), axis) - .Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("softmax_y", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("fused_scale_mask_softmax_dropout_grad") - .Input("softmax_y") - .Input("dy") - .Input("mask") - .Input("dropout_mask") - .Output("dx") - .Attr("scale_value") - .Attr("dropout_scale_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_EQ_OR_RETURN(dy_desc.shape(), softmax_y_desc.shape()); - CHECK_OR_RETURN(dy_desc.shape() == mask_desc.shape()); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = dy_desc.shape(); - *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_data_type() = dy_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - 
CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("softmax_y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dropout_mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_EQ_OR_RETURN(dy_desc.shape(), softmax_y_desc.shape()); + CHECK_OR_RETURN(dy_desc.shape() == mask_desc.shape()); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = dy_desc.shape(); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedScaleMaskSoftmaxDropoutGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_data_type() = dy_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxDropoutGradOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("softmax_y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dropout_mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_scale_mask_softmax_dropout") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -136,6 +132,4 @@ REGISTER_USER_OP_GRAD("fused_scale_mask_softmax_dropout") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp index 578caa3b495..685c62071f2 100644 --- a/oneflow/user/ops/fused_scale_mask_softmax_op.cpp +++ b/oneflow/user/ops/fused_scale_mask_softmax_op.cpp @@ -14,92 +14,91 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ auto FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); + *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedScaleMaskSoftmaxOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + *ctx->OutputDType("y", 0) = x_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("y", 0), axis) + .Build(); + } + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_scale_mask_softmax") - .Input("x") - .Input("mask") - .Output("y") - .Attr("scale_value", 1.0) - .Attr("mask_fill_value", 0.) 
- .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(x_desc.shape() == mask_desc.shape()); - *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("y", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("fused_scale_mask_softmax_grad") - .Input("y") - .Input("dy") - .Input("mask") - .Output("dx") - .Attr("scale_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_EQ_OR_RETURN(dy_desc.shape(), y_desc.shape()); - CHECK_OR_RETURN(y_desc.shape() == mask_desc.shape()); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = dy_desc.shape(); - *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); - const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); - CHECK_OR_RETURN(dy_desc.data_type() == y_desc.data_type()); - CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_data_type() = dy_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); + const user_op::TensorDesc& mask_desc = 
ctx->InputTensorDesc("mask", 0); + CHECK_EQ_OR_RETURN(dy_desc.shape(), y_desc.shape()); + CHECK_OR_RETURN(y_desc.shape() == mask_desc.shape()); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = dy_desc.shape(); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return FusedScaleMaskSoftmaxGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedScaleMaskSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& y_desc = ctx->InputTensorDesc("y", 0); + const user_op::TensorDesc& mask_desc = ctx->InputTensorDesc("mask", 0); + CHECK_OR_RETURN(dy_desc.data_type() == y_desc.data_type()); + CHECK_OR_RETURN(mask_desc.data_type() == DataType::kInt8); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_data_type() = dy_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedScaleMaskSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_scale_mask_softmax") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -119,6 +118,4 @@ REGISTER_USER_OP_GRAD("fused_scale_mask_softmax") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp index f0905fd8f98..20dead6c8d7 100644 --- a/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp +++ b/oneflow/user/ops/fused_scale_tril_softmax_mask_scale_op.cpp @@ -14,93 +14,88 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + *ctx->OutputShape("y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); + *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); + *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedTrilScaleSoftmaxMaskScaleOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + *ctx->OutputDType("y", 0) = x_desc.data_type(); + *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) + -> Maybe { + user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); + CHECK_OR_RETURN(mask_modifier != nullptr); + mask_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("y", 0), axis) + .Split(user_op::OpArg("softmax_y", 0), axis) + .Build(); + } + return Maybe::Ok(); +} -namespace { - -REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale") - .Input("x") - .Input("mask") - .Output("y") - .Output("softmax_y") - .Attr("diagonal") - .Attr("tril_fill_value", 0) - .Attr("tril_scale_value", 1.0) - .Attr("mask_scale_value", 1.0) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputShape("y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("y", 0) = x_desc.is_dynamic(); - *ctx->OutputShape("softmax_y", 0) = x_desc.shape(); - *ctx->OutputIsDynamic("softmax_y", 0) = x_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - *ctx->OutputDType("y", 0) = x_desc.data_type(); - *ctx->OutputDType("softmax_y", 0) = x_desc.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* mask_modifier = GetInputArgModifierFn("mask", 0); - CHECK_OR_RETURN(mask_modifier != nullptr); - mask_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - CHECK_GE_OR_RETURN(x_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, x_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - 
.Split(user_op::OpArg("y", 0), axis) - .Split(user_op::OpArg("softmax_y", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); - -REGISTER_USER_OP("fused_tril_scale_softmax_mask_scale_grad") - .Input("softmax_y") - .Input("dy") - .Input("mask") - .Output("dx") - .Attr("diagonal") - .Attr("tril_scale_value") - .Attr("mask_scale_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape()); - *dx_desc->mut_shape() = dy_desc.shape(); - *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); - const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); - *dx_desc->mut_data_type() = dy_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); - CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); - FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder() - .Split(user_op::OpArg("softmax_y", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("mask", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + CHECK_OR_RETURN(dy_desc.shape() == softmax_y_desc.shape()); + *dx_desc->mut_shape() = dy_desc.shape(); + *dx_desc->mut_is_dynamic() = dy_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedTrilScaleSoftmaxMaskScaleGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const user_op::TensorDesc& softmax_y_desc = ctx->InputTensorDesc("softmax_y", 0); + const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + CHECK_OR_RETURN(dy_desc.data_type() == softmax_y_desc.data_type()); + *dx_desc->mut_data_type() = dy_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ auto FusedTrilScaleSoftmaxMaskScaleGradOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0); + CHECK_GE_OR_RETURN(dy_tensor.shape().NumAxes(), 2); + FOR_RANGE(int64_t, axis, 0, dy_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder() + .Split(user_op::OpArg("softmax_y", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("mask", 0), axis) + .Split(user_op::OpArg("dx", 0), axis) + .Build(); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_tril_scale_softmax_mask_scale") 
.SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -123,6 +118,4 @@ REGISTER_USER_OP_GRAD("fused_tril_scale_softmax_mask_scale") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp index 0748daa1b22..232a78189c9 100644 --- a/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp +++ b/oneflow/user/ops/fused_self_attention_query_mul_key_and_value_ops.cpp @@ -14,110 +14,113 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("fused_self_attention_query_mul_key_and_value") - .Input("hidden_states") - .Output("query_mul_key") - .Output("value") - .Attr("head_size") - .Attr("alpha") - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& dtype = ctx->InputDType("hidden_states", 0); - *ctx->OutputDType("query_mul_key", 0) = dtype; - *ctx->OutputDType("value", 0) = dtype; - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_OR_RETURN(!(ctx->InputIsDynamic("hidden_states", 0))); - int64_t head_size = ctx->Attr("head_size"); - const Shape& hidden_states_shape = ctx->InputShape("hidden_states", 0); - // hidden_states_shape (seq_len, batch_size, hidden_size) - // layout is (seq_len, batch_size, num_heads, 3, head_size) - // for example shape (1024, 4, 12, 3, 64) -> (1024, 4, 12, 192) which stride is (9216, 2304, - // 192, 1) - CHECK_EQ_OR_RETURN(hidden_states_shape.NumAxes(), 3); - int64_t seq_len = hidden_states_shape.At(0); - int64_t batch_size = hidden_states_shape.At(1); - int64_t hidden_size = hidden_states_shape.At(2); - CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); - int64_t num_heads = hidden_size / (head_size * 3); +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + const DataType& dtype = ctx->InputDType("hidden_states", 0); + *ctx->OutputDType("query_mul_key", 0) = dtype; + *ctx->OutputDType("value", 0) = dtype; + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + CHECK_OR_RETURN(!(ctx->InputIsDynamic("hidden_states", 0))); + int64_t head_size = ctx->Attr("head_size"); + const Shape& hidden_states_shape = ctx->InputShape("hidden_states", 0); + // hidden_states_shape (seq_len, batch_size, hidden_size) + // layout is (seq_len, batch_size, num_heads, 3, head_size) + // for example shape (1024, 4, 12, 3, 64) -> (1024, 4, 12, 192) which stride is (9216, 2304, + // 192, 1) + CHECK_EQ_OR_RETURN(hidden_states_shape.NumAxes(), 3); + int64_t seq_len = hidden_states_shape.At(0); + int64_t batch_size = hidden_states_shape.At(1); + int64_t hidden_size = hidden_states_shape.At(2); + CHECK_EQ_OR_RETURN(hidden_size % (head_size * 3), 0); + int64_t num_heads = hidden_size / (head_size * 3); - *ctx->OutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); - *ctx->OutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); + *ctx->OutputShape("query_mul_key", 0) = Shape({batch_size, num_heads, seq_len, seq_len}); + *ctx->OutputShape("value", 0) = Shape({batch_size, num_heads, seq_len, head_size}); - return Maybe::Ok(); - }) - 
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("hidden_states", 0), 1) - .Split(user_op::OpArg("query_mul_key", 0), 0) - .Split(user_op::OpArg("value", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("hidden_states", 0), 2) - .Split(user_op::OpArg("query_mul_key", 0), 1) - .Split(user_op::OpArg("value", 0), 1) - .Build(); - return Maybe::Ok(); - }); + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedSelfAttentionQueryMulKeyAndValueOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("hidden_states", 0), 1) + .Split(user_op::OpArg("query_mul_key", 0), 0) + .Split(user_op::OpArg("value", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("hidden_states", 0), 2) + .Split(user_op::OpArg("query_mul_key", 0), 1) + .Split(user_op::OpArg("value", 0), 1) + .Build(); + return Maybe::Ok(); +} -REGISTER_USER_OP("fused_self_attention_query_mul_key_and_value_grad") - .Input("query_mul_key_grad") - .Input("value_grad") - .Input("hidden_states") - .Output("hidden_states_grad") - .Attr("alpha") - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& dtype = ctx->InputDType("query_mul_key_grad", 0); - CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype); - *ctx->OutputDType("hidden_states_grad", 0) = dtype; - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_OR_RETURN(!(ctx->InputIsDynamic("query_mul_key_grad", 0))); - CHECK_OR_RETURN(!(ctx->InputIsDynamic("value_grad", 0))); - const Shape& h_shape = ctx->InputShape("hidden_states", 0); - const Shape& qmk_grad_shape = ctx->InputShape("query_mul_key_grad", 0); - const Shape& v_grad_shape = ctx->InputShape("value_grad", 0); - CHECK_EQ_OR_RETURN(h_shape.NumAxes(), 3); - CHECK_EQ_OR_RETURN(qmk_grad_shape.NumAxes(), 4); - CHECK_EQ_OR_RETURN(v_grad_shape.NumAxes(), 4); - // hidden_states shape (s, b, H) - int64_t seq_len = h_shape.At(0); - int64_t batch_size = h_shape.At(1); - int64_t hidden_size = h_shape.At(2); - // value grad shape (b, n, s, h) - int64_t num_heads = v_grad_shape.At(1); - int64_t head_size = v_grad_shape.At(3); - CHECK_EQ_OR_RETURN(v_grad_shape.At(0), batch_size); - CHECK_EQ_OR_RETURN(v_grad_shape.At(2), seq_len); - CHECK_EQ_OR_RETURN(hidden_size, num_heads * 3 * head_size); - // qmk grad shape (b, n, sq, sk) - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(0), batch_size); - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(1), num_heads); - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); - CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferDataType( + user_op::InferContext* ctx) -> Maybe { + const DataType& dtype = ctx->InputDType("query_mul_key_grad", 0); + CHECK_EQ_OR_RETURN(ctx->InputDType("value_grad", 0), dtype); + *ctx->OutputDType("hidden_states_grad", 0) = dtype; + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + CHECK_OR_RETURN(!(ctx->InputIsDynamic("query_mul_key_grad", 0))); + CHECK_OR_RETURN(!(ctx->InputIsDynamic("value_grad", 0))); + const Shape& h_shape = ctx->InputShape("hidden_states", 0); + const Shape& qmk_grad_shape = ctx->InputShape("query_mul_key_grad", 0); + const Shape& 
v_grad_shape = ctx->InputShape("value_grad", 0); + CHECK_EQ_OR_RETURN(h_shape.NumAxes(), 3); + CHECK_EQ_OR_RETURN(qmk_grad_shape.NumAxes(), 4); + CHECK_EQ_OR_RETURN(v_grad_shape.NumAxes(), 4); + // hidden_states shape (s, b, H) + int64_t seq_len = h_shape.At(0); + int64_t batch_size = h_shape.At(1); + int64_t hidden_size = h_shape.At(2); + // value grad shape (b, n, s, h) + int64_t num_heads = v_grad_shape.At(1); + int64_t head_size = v_grad_shape.At(3); + CHECK_EQ_OR_RETURN(v_grad_shape.At(0), batch_size); + CHECK_EQ_OR_RETURN(v_grad_shape.At(2), seq_len); + CHECK_EQ_OR_RETURN(hidden_size, num_heads * 3 * head_size); + // qmk grad shape (b, n, sq, sk) + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(0), batch_size); + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(1), num_heads); + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(2), seq_len); + CHECK_EQ_OR_RETURN(qmk_grad_shape.At(3), seq_len); - *ctx->OutputShape("hidden_states_grad", 0) = h_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("query_mul_key_grad", 0), 0) - .Split(user_op::OpArg("value_grad", 0), 0) - .Split(user_op::OpArg("hidden_states", 0), 1) - .Split(user_op::OpArg("hidden_states_grad", 0), 1) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("query_mul_key_grad", 0), 1) - .Split(user_op::OpArg("value_grad", 0), 1) - .Split(user_op::OpArg("hidden_states", 0), 2) - .Split(user_op::OpArg("hidden_states_grad", 0), 2) - .Build(); - return Maybe::Ok(); - }); + *ctx->OutputShape("hidden_states_grad", 0) = h_shape; + return Maybe::Ok(); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return FusedSelfAttentionQueryMulKeyAndValueGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto FusedSelfAttentionQueryMulKeyAndValueGradOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("query_mul_key_grad", 0), 0) + .Split(user_op::OpArg("value_grad", 0), 0) + .Split(user_op::OpArg("hidden_states", 0), 1) + .Split(user_op::OpArg("hidden_states_grad", 0), 1) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("query_mul_key_grad", 0), 1) + .Split(user_op::OpArg("value_grad", 0), 1) + .Split(user_op::OpArg("hidden_states", 0), 2) + .Split(user_op::OpArg("hidden_states_grad", 0), 2) + .Build(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_self_attention_query_mul_key_and_value") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/gather_op.cpp b/oneflow/user/ops/gather_op.cpp index 47045ef4c0c..87ded29ab9c 100644 --- a/oneflow/user/ops/gather_op.cpp +++ b/oneflow/user/ops/gather_op.cpp @@ -14,80 +14,79 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("gather") - .Input("in") - .Input("indices") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); - const int64_t axis = ctx->Attr("axis"); - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); +/*static*/ auto GatherOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0); + const int64_t axis = ctx->Attr("axis"); + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + CHECK_GT_OR_RETURN(indices.shape().NumAxes(), 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), - in.shape().dim_vec().cbegin() + axis); - dim_vec.insert(dim_vec.end(), indices.shape().dim_vec().cbegin(), - indices.shape().dim_vec().cend()); - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + axis + 1, - in.shape().dim_vec().end()); - *out->mut_shape() = Shape(dim_vec); - out->set_is_dynamic(indices.is_dynamic() || in.is_dynamic()); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t in_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); - const int64_t indices_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); - const int64_t gather_axis = ctx->Attr("axis"); - CHECK_GE_OR_RETURN(gather_axis, 0); - CHECK_LT_OR_RETURN(gather_axis, in_num_axes); - FOR_RANGE(int64_t, i, 0, indices_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Broadcast(user_op::OpArg("in", 0)) - .Split(user_op::OpArg("out", 0), gather_axis + i) - .Build(); - } - FOR_RANGE(int64_t, i, 0, in_num_axes) { - if (i == gather_axis) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("in", 0), i) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } else { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i < gather_axis ? 
i : i + indices_num_axes - 1) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }); + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), + in.shape().dim_vec().cbegin() + axis); + dim_vec.insert(dim_vec.end(), indices.shape().dim_vec().cbegin(), + indices.shape().dim_vec().cend()); + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + axis + 1, + in.shape().dim_vec().end()); + *out->mut_shape() = Shape(dim_vec); + out->set_is_dynamic(indices.is_dynamic() || in.is_dynamic()); + return Maybe::Ok(); +} +/*static*/ auto GatherOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GatherOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GatherOp::ModifyInputArg(const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper&) -> Maybe { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ auto GatherOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const int64_t in_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); + const int64_t indices_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0).shape().NumAxes(); + const int64_t gather_axis = ctx->Attr("axis"); + CHECK_GE_OR_RETURN(gather_axis, 0); + CHECK_LT_OR_RETURN(gather_axis, in_num_axes); + FOR_RANGE(int64_t, i, 0, indices_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Broadcast(user_op::OpArg("in", 0)) + .Split(user_op::OpArg("out", 0), gather_axis + i) + .Build(); + } + FOR_RANGE(int64_t, i, 0, in_num_axes) { + if (i == gather_axis) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("in", 0), i) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } else { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i < gather_axis ? i : i + indices_num_axes - 1) + .Build(); + } + } + return Maybe::Ok(); +} +/*static*/ auto GatherOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const user_op::TensorDesc& indices = ctx->InputTensorDesc("indices", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(IsIndexDataType(indices.data_type())); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("gather").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/gelu_op.cpp b/oneflow/user/ops/gelu_op.cpp index d2f374052b5..39f12592c23 100644 --- a/oneflow/user/ops/gelu_op.cpp +++ b/oneflow/user/ops/gelu_op.cpp @@ -14,66 +14,63 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("gelu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ auto GeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ auto GeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GeluOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GeluOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ auto GeluOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("gelu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ auto GeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ auto GeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GeluGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GeluGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + 
FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto GeluGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("gelu").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp index 5de4f2e6280..73b7dcb52eb 100644 --- a/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp +++ b/oneflow/user/ops/generate_random_batch_permutation_indices_op.cpp @@ -15,34 +15,32 @@ limitations under the License. */ #include #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("generate_random_batch_permutation_indices") - .Input("x") - .Output("y") - .Attr("seed") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("y", 0)) - .Build(); - const auto& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Broadcast(user_op::OpArg("y", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = DataType::kInt32; - return Maybe::Ok(); - }); +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + *ctx->OutputShape("y", 0) = Shape({ctx->InputShape("x", 0).At(0)}); + return Maybe::Ok(); +} +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) -> Maybe { + return GenerateRandomBatchPermutationIndicesOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::GetSbp(user_op::SbpContext* ctx) + -> Maybe { + ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).Broadcast(user_op::OpArg("y", 0)).Build(); + const auto& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Broadcast(user_op::OpArg("y", 0)).Build(); + } + return Maybe::Ok(); +} +/*static*/ auto GenerateRandomBatchPermutationIndicesOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + *ctx->OutputDType("y", 0) = DataType::kInt32; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/gpt_data_loader_op.cpp b/oneflow/user/ops/gpt_data_loader_op.cpp index a67f0d4077a..ab66be504e1 100644 --- a/oneflow/user/ops/gpt_data_loader_op.cpp +++ b/oneflow/user/ops/gpt_data_loader_op.cpp @@ -14,51 +14,42 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("megatron_gpt_mmap_data_loader") - .OptionalInput("iteration") - .Output("out") - .Attr("data_file_prefix") - .Attr("seq_length") - .Attr("label_length", 1) - .Attr("num_samples") - .Attr("batch_size") - .Attr("dtype") - .Attr>("split_sizes") - .Attr("split_index") - .Attr("shuffle") - .Attr("random_seed") - .Attr>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - int64_t batch_size = ctx->Attr("batch_size"); - int64_t sample_len = ctx->Attr("seq_length") + ctx->Attr("label_length"); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = Shape({batch_size, sample_len}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputTensorDesc("out", 0)->mut_data_type() = ctx->Attr("dtype"); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_split_parallel()->set_axis(0); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - if (!conf.has_input("iteration", 0)) { return Maybe::Ok(); } - user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); - CHECK_OR_RETURN(input_modifier != nullptr); - input_modifier->set_is_mutable(true); - input_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); +/*static*/ auto MegatronGptMmapDataLoaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + int64_t batch_size = ctx->Attr("batch_size"); + int64_t sample_len = ctx->Attr("seq_length") + ctx->Attr("label_length"); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = Shape({batch_size, sample_len}); + return Maybe::Ok(); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::InferDataType(user_op::InferContext* ctx) + -> Maybe { + *ctx->OutputTensorDesc("out", 0)->mut_data_type() = ctx->Attr("dtype"); + return Maybe::Ok(); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) + -> Maybe { + cfg::SbpParallel default_sbp; + default_sbp.mutable_split_parallel()->set_axis(0); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} +/*static*/ auto MegatronGptMmapDataLoaderOp::ModifyInputArg( + const user_op::GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) -> Maybe { + if (!conf.has_input("iteration", 0)) { return Maybe::Ok(); } + user_op::InputArgModifier* input_modifier = GetInputArgModifierFn("iteration", 0); + CHECK_OR_RETURN(input_modifier != nullptr); + input_modifier->set_is_mutable(true); + input_modifier->set_requires_grad(false); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/grid_sample_op.cpp b/oneflow/user/ops/grid_sample_op.cpp index c415d858987..d9c81470a7d 100644 --- a/oneflow/user/ops/grid_sample_op.cpp +++ b/oneflow/user/ops/grid_sample_op.cpp @@ -14,13 +14,12 @@ See the License for the specific 
language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe CheckAttr(const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) { +Maybe GridSampleOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { bool pass_checked = true; std::stringstream err; err << "Illegal value for " << conf.op_type_name() << " op " << conf.op_name() << ": "; @@ -45,110 +44,103 @@ Maybe CheckAttr(const user_op::UserOpDefWrapper& def, } } -} // namespace - -REGISTER_USER_OP("grid_sample") - .Input("input") - .Input("grid") - .Output("output") - .Attr("interpolation_mode") - .Attr("padding_mode") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); - const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0); - user_op::TensorDesc& output = *(ctx->OutputTensorDesc("output", 0)); - // Only support 4D or 5D input with NCHW layout - // For 4D grid: input = { N, C, H_in, W_in }, - // grid = { N, H_out, W_out, 2 } - // output = { N, C, H_out, W_out } - // For 5D grid: input = { N, C, D_in, H_in, W_in }, - // grid = { N, D_out, H_out, W_out, 3 } - // output = { N, C, D_out, H_out, W_out } - const Shape& input_shape = input.shape(); - const Shape& grid_shape = grid.shape(); +/*static*/ auto GridSampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc& input = ctx->InputTensorDesc("input", 0); + const user_op::TensorDesc& grid = ctx->InputTensorDesc("grid", 0); + user_op::TensorDesc& output = *(ctx->OutputTensorDesc("output", 0)); + // Only support 4D or 5D input with NCHW layout + // For 4D grid: input = { N, C, H_in, W_in }, + // grid = { N, H_out, W_out, 2 } + // output = { N, C, H_out, W_out } + // For 5D grid: input = { N, C, D_in, H_in, W_in }, + // grid = { N, D_out, H_out, W_out, 3 } + // output = { N, C, D_out, H_out, W_out } + const Shape& input_shape = input.shape(); + const Shape& grid_shape = grid.shape(); + + bool is_4d_input = true; + if (input_shape.NumAxes() == 4) { + CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << "Grid and input MUST have same dimention"; + CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << "Grid shape MUST (N, H_out, W_out, 2)"; + is_4d_input = true; + } else if (input_shape.NumAxes() == 5) { + CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << "Grid and input MUST have same dimention"; + CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << "Grid shape MUST (N, H_out, W_out, 3)"; + if (ctx->Attr("interpolation_mode") == "bicubic") { + oneflow::Error::CheckFailedError() << "Mode='bicubic' supports only 4-D input"; + } + is_4d_input = false; + } else { + CHECK_OR_RETURN(false) << "MUST be 4D or 5D input"; + } + *output.mut_is_dynamic() = grid.is_dynamic(); + if (is_4d_input) { + *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), + grid_shape.At(2)}; + } else { + *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), + grid_shape.At(2), grid_shape.At(3)}; + } + return Maybe::Ok(); +} +/*static*/ auto GridSampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) -> Maybe { + return GridSampleOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GridSampleOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), 0) + 
.Split(user_op::OpArg("grid", 0), 0) + .Split(user_op::OpArg("output", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), 1) + .Broadcast(user_op::OpArg("grid", 0)) + .Split(user_op::OpArg("output", 0), 1) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto GridSampleOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe::Ok(); +} - bool is_4d_input = true; - if (input_shape.NumAxes() == 4) { - CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 4) << "Grid and input MUST have same dimention"; - CHECK_EQ_OR_RETURN(grid_shape.At(3), 2) << "Grid shape MUST (N, H_out, W_out, 2)"; - is_4d_input = true; - } else if (input_shape.NumAxes() == 5) { - CHECK_EQ_OR_RETURN(grid_shape.NumAxes(), 5) << "Grid and input MUST have same dimention"; - CHECK_EQ_OR_RETURN(grid_shape.At(4), 3) << "Grid shape MUST (N, H_out, W_out, 3)"; - if (ctx->Attr("interpolation_mode") == "bicubic") { - oneflow::Error::CheckFailedError() << "Mode='bicubic' supports only 4-D input"; - } - is_4d_input = false; - } else { - CHECK_OR_RETURN(false) << "MUST be 4D or 5D input"; - } - *output.mut_is_dynamic() = grid.is_dynamic(); - if (is_4d_input) { - *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), - grid_shape.At(2)}; - } else { - *(output.mut_shape()) = {input_shape.At(0), input_shape.At(1), grid_shape.At(1), - grid_shape.At(2), grid_shape.At(3)}; - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), 0) - .Split(user_op::OpArg("grid", 0), 0) - .Split(user_op::OpArg("output", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), 1) - .Broadcast(user_op::OpArg("grid", 0)) - .Split(user_op::OpArg("output", 0), 1) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); - return Maybe::Ok(); - }); +Maybe GridSampleGradOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + return GridSampleOp::CheckAttr(def, conf); +} -REGISTER_USER_OP("grid_sample_grad") - .Input("doutput") - .Input("input") - .Input("grid") - .Output("dinput") - .Output("dgrid") - .Attr("interpolation_mode") - .Attr("padding_mode") - .Attr("align_corners") - .SetCheckAttrFn(CheckAttr) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *(ctx->OutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); - *(ctx->OutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("doutput", 0), 0) - .Split(user_op::OpArg("input", 0), 0) - .Split(user_op::OpArg("grid", 0), 0) - .Split(user_op::OpArg("dinput", 0), 0) - .Split(user_op::OpArg("dgrid", 0), 0) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("doutput", 0), 1) - .Split(user_op::OpArg("input", 0), 1) - .Broadcast(user_op::OpArg("grid", 0)) - .Split(user_op::OpArg("dinput", 0), 1) - .Broadcast(user_op::OpArg("dgrid", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dinput", 0) = ctx->InputDType("input", 0); - *ctx->OutputDType("dgrid", 0) = ctx->InputDType("grid", 0); - return Maybe::Ok(); - }); +/*static*/ auto 
GridSampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + *(ctx->OutputTensorDesc("dinput", 0)->mut_shape()) = ctx->InputTensorDesc("input", 0).shape(); + *(ctx->OutputTensorDesc("dgrid", 0)->mut_shape()) = ctx->InputTensorDesc("grid", 0).shape(); + return Maybe::Ok(); +} +/*static*/ auto GridSampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) + -> Maybe { + return GridSampleGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ auto GridSampleGradOp::GetSbp(user_op::SbpContext* ctx) -> Maybe { + ctx->NewBuilder() + .Split(user_op::OpArg("doutput", 0), 0) + .Split(user_op::OpArg("input", 0), 0) + .Split(user_op::OpArg("grid", 0), 0) + .Split(user_op::OpArg("dinput", 0), 0) + .Split(user_op::OpArg("dgrid", 0), 0) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("doutput", 0), 1) + .Split(user_op::OpArg("input", 0), 1) + .Broadcast(user_op::OpArg("grid", 0)) + .Split(user_op::OpArg("dinput", 0), 1) + .Broadcast(user_op::OpArg("dgrid", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ auto GridSampleGradOp::InferDataType(user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("dinput", 0) = ctx->InputDType("input", 0); + *ctx->OutputDType("dgrid", 0) = ctx->InputDType("grid", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("grid_sample") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/hardsigmoid_op.cpp b/oneflow/user/ops/hardsigmoid_op.cpp index cdf43671ab2..887614425ac 100644 --- a/oneflow/user/ops/hardsigmoid_op.cpp +++ b/oneflow/user/ops/hardsigmoid_op.cpp @@ -14,63 +14,64 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe HardsigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} -REGISTER_USER_OP("hardsigmoid") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe HardsigmoidOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("hardsigmoid_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, 
x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe HardsigmoidOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardsigmoidOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HardsigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe HardsigmoidGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HardsigmoidGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardsigmoidGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hardsigmoid") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -89,6 +90,4 @@ REGISTER_USER_OP_GRAD("hardsigmoid") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/hardswish_op.cpp b/oneflow/user/ops/hardswish_op.cpp index 45d0ebe230c..f7dfbc5c870 100644 --- a/oneflow/user/ops/hardswish_op.cpp +++ b/oneflow/user/ops/hardswish_op.cpp @@ -14,61 +14,62 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe HardswishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("hardswish") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe HardswishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("hardswish_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe HardswishOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardswishOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HardswishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe HardswishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HardswishGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe 
HardswishGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hardswish") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -87,6 +88,4 @@ REGISTER_USER_OP_GRAD("hardswish") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/hardtanh_op.cpp b/oneflow/user/ops/hardtanh_op.cpp index 2962c49e99e..2d5208c7b0b 100644 --- a/oneflow/user/ops/hardtanh_op.cpp +++ b/oneflow/user/ops/hardtanh_op.cpp @@ -14,73 +14,70 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe HardtanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + double min_val = ctx->Attr("min_val"); + double max_val = ctx->Attr("max_val"); + CHECK_LE_OR_RETURN(min_val, max_val); + return Maybe::Ok(); +} -REGISTER_USER_OP("hardtanh") - .Input("in") - .Attr("min_val") - .Attr("max_val") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - double min_val = ctx->Attr("min_val"); - double max_val = ctx->Attr("max_val"); - CHECK_LE_OR_RETURN(min_val, max_val); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe HardtanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("hardtanh_grad") - .Input("y") - .Input("dy") - .Attr("min_val") - .Attr("max_val") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - double min_val = ctx->Attr("min_val"); - double max_val = ctx->Attr("max_val"); - CHECK_LE_OR_RETURN(min_val, max_val); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); +/* static */ Maybe HardtanhOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& 
in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardtanhOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HardtanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + double min_val = ctx->Attr("min_val"); + double max_val = ctx->Attr("max_val"); + CHECK_LE_OR_RETURN(min_val, max_val); + return Maybe::Ok(); +} + +/*static*/ Maybe HardtanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HardtanhGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe HardtanhGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hardtanh") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -101,6 +98,4 @@ REGISTER_USER_OP_GRAD("hardtanh") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp index 8ac12f2573a..8fa5b36bf49 100644 --- a/oneflow/user/ops/hierarchical_parallel_cast_op.cpp +++ b/oneflow/user/ops/hierarchical_parallel_cast_op.cpp @@ -15,64 +15,78 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("hierarchical_parallel_cast") - .Input("in") - .Output("out") - .Attr>("nd_sbp") - .Attr("grad_mode") - .Attr>("grad_nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - const auto& conf = ctx->user_op_conf().attr>("nd_sbp"); - CHECK_EQ_OR_RETURN(conf.size(), parallel_hierarchy.NumAxes()); - for (const std::string& sbp_str : conf) { - cfg::SbpParallel sbp_parallel; - CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); - *in_distribution->add_sbp_parallel() = sbp_parallel; - *out_distribution->add_sbp_parallel() = sbp_parallel; - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_USER_OP("hierarchical_parallel_cast_like") - .Input("in") - .Input("like") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); - const cfg::NdSbp& hint_distribution = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); - *in_distribution = hint_distribution; - *out_distribution = hint_distribution; - *like_distribution = hint_distribution; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe HierarchicalParallelCastOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe HierarchicalParallelCastOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HierarchicalParallelCastOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe HierarchicalParallelCastOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + const auto& conf = ctx->user_op_conf().attr>("nd_sbp"); + CHECK_EQ_OR_RETURN(conf.size(), parallel_hierarchy.NumAxes()); + for (const std::string& sbp_str : conf) { + cfg::SbpParallel sbp_parallel; + 
CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); + *in_distribution->add_sbp_parallel() = sbp_parallel; + *out_distribution->add_sbp_parallel() = sbp_parallel; + } + return Maybe::Ok(); +} + +/* static */ Maybe HierarchicalParallelCastOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe HierarchicalParallelCastLikeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + cfg::NdSbp* like_distribution = ctx->NdSbp4ArgNameAndIndex("like", 0); + const cfg::NdSbp& hint_distribution = ctx->NdSbpHint4InputArgNameAndIndex("like", 0); + *in_distribution = hint_distribution; + *out_distribution = hint_distribution; + *like_distribution = hint_distribution; + return Maybe::Ok(); +} + +/* static */ Maybe HierarchicalParallelCastLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("hierarchical_parallel_cast") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/identity_op.cpp b/oneflow/user/ops/identity_op.cpp index 2e67cefc4c6..538abeb5dde 100644 --- a/oneflow/user/ops/identity_op.cpp +++ b/oneflow/user/ops/identity_op.cpp @@ -14,37 +14,36 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe IdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("identity") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe IdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IdentityOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe IdentityOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("identity") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -62,6 +61,4 @@ REGISTER_USER_OP_GRAD("identity") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/image_batch_align_op.cpp b/oneflow/user/ops/image_batch_align_op.cpp index d6eaa08396c..0563281485b 100644 --- a/oneflow/user/ops/image_batch_align_op.cpp +++ b/oneflow/user/ops/image_batch_align_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -27,70 +28,71 @@ bool PowerOfTwo(T x) { } // namespace -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_batch_align") - .Input("in") - .Output("out") - .Attr("shape") - .Attr("data_type") - .Attr("alignment") - .Attr("dynamic_out") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::stringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const Shape& shape = conf.attr("shape"); - if (shape.NumAxes() != 3) { - err << ", shape: " << shape.ToString() << " (image shape must has 3 axes)"; - check_failed = true; - } - DataType data_type = conf.attr("data_type"); - if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { - err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; - check_failed = true; - } - int32_t alignment = conf.attr("alignment"); - if (alignment < 0) { - err << ", alignment: " << alignment << " (alignment must be greater than or equal to 0)"; - check_failed = true; - } else if (alignment != 0 && !PowerOfTwo(alignment)) { - err << ", alignment: " << alignment - << " (alignment must be power of 2 when it's not equal to 0)"; - check_failed = true; - } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1); - const Shape& shape_attr = ctx->Attr("shape"); - const bool dynamic_out = ctx->Attr("dynamic_out"); - DimVector dim_vec(shape_attr.NumAxes() + 1); - dim_vec.at(0) = in_desc.shape().elem_cnt(); - FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); } - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = Shape(dim_vec); - out_desc->set_is_dynamic(dynamic_out); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - out_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = ctx->Attr("data_type"); - return Maybe::Ok(); - }); +/* static */ Maybe ImageBatchAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1); + const Shape& shape_attr = ctx->Attr("shape"); + const bool dynamic_out = ctx->Attr("dynamic_out"); + DimVector dim_vec(shape_attr.NumAxes() + 1); + dim_vec.at(0) = in_desc.shape().elem_cnt(); + FOR_RANGE(int64_t, i, 0, shape_attr.NumAxes()) { dim_vec.at(i + 1) = shape_attr.At(i); } + user_op::TensorDesc* out_desc = 
ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = Shape(dim_vec); + out_desc->set_is_dynamic(dynamic_out); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ImageBatchAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ImageBatchAlignOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageBatchAlignOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + out_modifier->set_header_infered_before_compute(false); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageBatchAlignOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::stringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const Shape& shape = conf.attr<Shape>("shape"); + if (shape.NumAxes() != 3) { + err << ", shape: " << shape.ToString() << " (image shape must have 3 axes)"; + check_failed = true; + } + DataType data_type = conf.attr<DataType>("data_type"); + if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { + err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; + check_failed = true; + } + int32_t alignment = conf.attr<int32_t>("alignment"); + if (alignment < 0) { + err << ", alignment: " << alignment << " (alignment must be greater than or equal to 0)"; + check_failed = true; + } else if (alignment != 0 && !PowerOfTwo(alignment)) { + err << ", alignment: " << alignment + << " (alignment must be power of 2 when it's not equal to 0)"; + check_failed = true; + } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageBatchAlignOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = ctx->Attr<DataType>("data_type"); + return Maybe<void>::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_decode_op.cpp b/oneflow/user/ops/image_decode_op.cpp index 7db39a67ae5..cd308ce528e 100644 --- a/oneflow/user/ops/image_decode_op.cpp +++ b/oneflow/user/ops/image_decode_op.cpp @@ -14,50 +14,53 @@ See the License for the specific language governing permissions and limitations under the License.
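One detail in image_batch_align_op.cpp above: the hunk context elides the body of the PowerOfTwo helper that CheckAttr calls for the alignment attr. A standard bit-trick implementation consistent with the call sites (an assumption, since the body is not visible in this diff):

template<typename T>
bool PowerOfTwo(T x) {
  // A positive power of two has exactly one bit set, so clearing the lowest
  // set bit with x & (x - 1) must leave zero.
  return x > 0 && ((x & (x - 1)) == 0);
}
// e.g. PowerOfTwo(16) -> true, PowerOfTwo(12) -> false; alignment == 0 is
// accepted by a separate branch, presumably meaning "no alignment requested".

Together with the alignment >= 0 check, the op accepts exactly 0 and positive powers of two.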
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_decode") - .Input("in") - .Output("out") - .Attr("color_space", "BGR") - .Attr("data_type", DataType::kUInt8) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::stringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const std::string& color_space = conf.attr("color_space"); - if (color_space != "BGR" && color_space != "RGB" && color_space != "GRAY") { - err << ", color_space: " << color_space - << " (color_space can only be one of BGR, RGB and GRAY)"; - check_failed = true; - } - DataType data_type = conf.attr("data_type"); - if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { - err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; - check_failed = true; - } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/* static */ Maybe ImageDecodeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageDecodeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageDecodeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ImageDecodeOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::stringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const std::string& color_space = conf.attr("color_space"); + if (color_space != "BGR" && color_space != "RGB" && color_space != "GRAY") { + err << ", color_space: " << color_space + << " (color_space can only be one of BGR, RGB and GRAY)"; + check_failed = true; + } + DataType data_type = conf.attr("data_type"); + if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { + err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; + check_failed = true; + } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + 
return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageDecodeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = DataType::kTensorBuffer; + return Maybe<void>::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_object_preprocess_ops.cpp b/oneflow/user/ops/image_object_preprocess_ops.cpp index 979e3ebf30c..5fd2cb99f38 100644 --- a/oneflow/user/ops/image_object_preprocess_ops.cpp +++ b/oneflow/user/ops/image_object_preprocess_ops.cpp @@ -14,209 +14,243 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe<void> GetSbp(user_op::SbpContext* ctx) { +Maybe<void> ImageObjectGetSbp(user_op::SbpContext* ctx) { ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); return Maybe<void>::Ok(); } } // namespace -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_flip") - .Input("in") - .Input("flip_code") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); - const int N = in_desc.shape().elem_cnt(); - - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_bbox_flip") - .Input("bbox") - .Input("image_size") - .Input("flip_code") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); - const int N = bbox_desc.shape().elem_cnt(); - - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); - *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); - return Maybe<void>::Ok(); - }); - 
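Note the rename at the top of image_object_preprocess_ops.cpp: the file-local helper GetSbp becomes ImageObjectGetSbp so it can no longer collide with the static GetSbp members introduced for each op class. Every member then forwards to it, following this shape (repeated from the diff for clarity, with the forwarding spelled out):

namespace {
// File-local SBP rule shared by all image-object ops: split every input and
// output along axis 0, the per-image batch dimension.
Maybe<void> ImageObjectGetSbp(user_op::SbpContext* ctx) {
  ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build();
  return Maybe<void>::Ok();
}
}  // namespace

/* static */ Maybe<void> ImageFlipOp::GetSbp(user_op::SbpContext* ctx) {
  return ImageObjectGetSbp(ctx);
}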
-REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_bbox_scale") - .Input("bbox") - .Input("scale") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); - const int N = bbox_desc.shape().elem_cnt(); - - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - - *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); - CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); - *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_segmentation_polygon_flip") - .Input("poly") - .Input("image_size") - .Input("flip_code") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); - const int N = poly_desc.shape().elem_cnt(); - - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); - - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); - CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_segmentation_polygon_scale") - .Input("poly") - .Input("scale") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); - const int N = poly_desc.shape().elem_cnt(); - - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); - - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); - 
CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_normalize") - .Input("in") - .Attr<std::vector<float>>("std") - .Attr<std::vector<float>>("mean") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("object_segmentation_polygon_to_mask") - .Input("poly") - .Input("poly_index") - .Input("image_size") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); - const int N = poly_desc.shape().elem_cnt(); - - const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); - CHECK_EQ_OR_RETURN(poly_index_desc.shape().NumAxes(), 1); - CHECK_EQ_OR_RETURN(poly_index_desc.shape().elem_cnt(), N); - - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); - - *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetSbp) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); - CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); - CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer); - const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); - CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); - *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); - return Maybe<void>::Ok(); - }); +/* static */ Maybe<void> ImageFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); + const int N = in_desc.shape().elem_cnt(); + + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); + + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ImageFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ImageFlipOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe<void> ImageFlipOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); +}
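Throughout these conversions the validation stays in the CHECK_*_OR_RETURN family rather than plain CHECK: a failed condition does not abort the process but makes the enclosing function return an error Maybe<void>, which is why every method here returns Maybe<void> and ends with Maybe<void>::Ok(). A sketch of the contract these macros are assumed to provide (an illustration, not OneFlow's actual macro definition):

// Hypothetical illustration only: what CHECK_EQ_OR_RETURN(a, b) guarantees
// inside a function returning Maybe<void>:
//   if (!((a) == (b))) { return Error::CheckFailedError() << "check failed: a == b"; }
// so a caller can write JUST(ImageFlipOp::InferDataType(ctx)) and have the
// failure propagate as an error value instead of crashing.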
+ +/* static */ Maybe<void> ObjectBboxFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); + const int N = bbox_desc.shape().elem_cnt(); + + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); + + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); + + *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ObjectBboxFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ObjectBboxFlipOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe<void> ObjectBboxFlipOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); + *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ObjectBboxScaleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.shape().NumAxes(), 1); + const int N = bbox_desc.shape().elem_cnt(); + + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); + + *ctx->OutputShape("out", 0) = ctx->InputShape("bbox", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("bbox", 0); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ObjectBboxScaleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ObjectBboxScaleOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe<void> ObjectBboxScaleOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& bbox_desc = ctx->InputTensorDesc("bbox", 0); + CHECK_EQ_OR_RETURN(bbox_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); + *ctx->OutputDType("out", 0) = ctx->InputDType("bbox", 0); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonFlipOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); + const int N = poly_desc.shape().elem_cnt(); + + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); + + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.shape().elem_cnt(), N); + + *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 
0); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ObjectSegmentationPolygonFlipOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonFlipOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonFlipOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); + const user_op::TensorDesc& flip_code_desc = ctx->InputTensorDesc("flip_code", 0); + CHECK_EQ_OR_RETURN(flip_code_desc.data_type(), DataType::kInt8); + *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonScaleOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); + const int N = poly_desc.shape().elem_cnt(); + + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.shape().elem_cnt(), N * 2); + + *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ObjectSegmentationPolygonScaleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonScaleOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonScaleOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& scale_desc = ctx->InputTensorDesc("scale", 0); + CHECK_EQ_OR_RETURN(scale_desc.data_type(), DataType::kFloat); + *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageNormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.shape().NumAxes(), 1); + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ImageNormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ImageNormalizeOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe<void> ImageNormalizeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_desc.data_type(), DataType::kTensorBuffer); + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ObjectSegmentationPolygonToMaskOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.shape().NumAxes(), 1); + const int N = poly_desc.shape().elem_cnt(); + + const user_op::TensorDesc& poly_index_desc = 
ctx->InputTensorDesc("poly_index", 0); + CHECK_EQ_OR_RETURN(poly_index_desc.shape().NumAxes(), 1); + CHECK_EQ_OR_RETURN(poly_index_desc.shape().elem_cnt(), N); + + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.shape().elem_cnt(), N * 2); + + *ctx->OutputShape("out", 0) = ctx->InputShape("poly", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("poly", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe ObjectSegmentationPolygonToMaskOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonToMaskOp::GetSbp(user_op::SbpContext* ctx) { + return ImageObjectGetSbp(ctx); +} + +/* static */ Maybe ObjectSegmentationPolygonToMaskOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& poly_desc = ctx->InputTensorDesc("poly", 0); + CHECK_EQ_OR_RETURN(poly_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& poly_index_desc = ctx->InputTensorDesc("poly_index", 0); + CHECK_EQ_OR_RETURN(poly_index_desc.data_type(), DataType::kTensorBuffer); + const user_op::TensorDesc& image_size_desc = ctx->InputTensorDesc("image_size", 0); + CHECK_EQ_OR_RETURN(image_size_desc.data_type(), DataType::kInt32); + *ctx->OutputDType("out", 0) = ctx->InputDType("poly", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_preprocess_ops.cpp b/oneflow/user/ops/image_preprocess_ops.cpp index a244a4ecab4..91bb5cee58f 100644 --- a/oneflow/user/ops/image_preprocess_ops.cpp +++ b/oneflow/user/ops/image_preprocess_ops.cpp @@ -17,236 +17,224 @@ limitations under the License. #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/core/job/sbp_parallel.h" #include "oneflow/user/image/image_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("crop_mirror_normalize_from_tensorbuffer") - .Input("in") - .OptionalInput("mirror") - .Output("out") - .Attr("color_space", "BGR") - .Attr("output_layout", "NCHW") - .Attr>("mean", {0.0}) - .Attr>("std", {1.0}) - .Attr("crop_h", 0) - .Attr("crop_w", 0) - .Attr("crop_pos_x", 0.5) - .Attr("crop_pos_y", 0.5) - .Attr("output_dtype", DataType::kFloat) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 - && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); - } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t N = in_tensor.shape().At(0); - int64_t H = ctx->Attr("crop_h"); - int64_t W = ctx->Attr("crop_w"); - std::string color_space = ctx->Attr("color_space"); - int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1; - - CHECK_OR_RETURN(H != 0 && W != 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1); - std::string output_layout = ctx->Attr<std::string>("output_layout"); - if (output_layout == "NCHW") { - *out_tensor->mut_shape() = Shape({N, C, H, W}); - } else if (output_layout == "NHWC") { - *out_tensor->mut_shape() = Shape({N, H, W, C}); - } else { - return Error::CheckFailedError() - << "output_layout: " << output_layout << " is not supported"; - } - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kTensorBuffer); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); - } +/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + bool has_mirror = ctx->has_input("mirror", 0); + if (has_mirror) { + const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); + CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 + && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); + } + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t N = in_tensor.shape().At(0); + int64_t H = ctx->Attr<int64_t>("crop_h"); + int64_t W = ctx->Attr<int64_t>("crop_w"); + std::string color_space = ctx->Attr<std::string>("color_space"); + int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1; - CHECK_EQ_OR_RETURN(in_tensor.shape().NumAxes(), 4); // {N, H, W, C} - CHECK_EQ_OR_RETURN(in_tensor.shape().At(3), C); - if (H == 0 || W == 0) { - H = in_tensor.shape().At(1); - W = in_tensor.shape().At(2); - } else { - H = std::min(H, in_tensor.shape().At(1)); - W = std::min(W, in_tensor.shape().At(2)); - } - std::string output_layout = ctx->Attr<std::string>("output_layout"); - if (output_layout == "NCHW") { - *out_tensor->mut_shape() = Shape({N, C, H, W}); - } else if (output_layout == "NHWC") { - *out_tensor->mut_shape() = Shape({N, H, W, C}); - } else { - return Error::CheckFailedError() - << "output_layout: " << output_layout << " is not supported"; - } + CHECK_OR_RETURN(H != 0 && W != 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1); + std::string output_layout = ctx->Attr<std::string>("output_layout"); + if (output_layout == "NCHW") { + *out_tensor->mut_shape() = Shape({N, C, H, W}); + } else if (output_layout == "NHWC") { + *out_tensor->mut_shape() = Shape({N, H, W, C}); + } else { + return Error::CheckFailedError() << "output_layout: " << output_layout << " is not supported"; + } + return Maybe<void>::Ok(); +} - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kUInt8); - bool has_mirror = ctx->has_input("mirror", 0); - if (has_mirror) { - const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); - CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); - } - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - DataType output_dtype = ctx->Attr<DataType>("output_dtype"); - CHECK_EQ_OR_RETURN(output_dtype, - DataType::kFloat); // only support float now; for float16 in future - *out_tensor->mut_data_type() = output_dtype; - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("coin_flip") - .Output("out") - .Attr<float>("probability", 0.5) - .Attr<int64_t>("batch_size") - .Attr<int64_t>("seed", -1) - .Attr<bool>("has_seed", false) - .Attr<std::vector<std::string>>("nd_sbp") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t batch_size = ctx->Attr<int64_t>("batch_size"); - *out_tensor->mut_shape() = Shape({batch_size}); - return Maybe<void>::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - int64_t batch_size = ctx->Attr<int64_t>("batch_size"); - const ParallelContext& parallel_ctx = ctx->parallel_ctx(); - const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); - if (parallel_ctx.parallel_num() > 1 && out_sbp.has_split_parallel()) { - BalancedSplitter bs(batch_size, parallel_ctx.parallel_num()); - *out_tensor->mut_shape() = Shape({bs.At(parallel_ctx.parallel_id()).size()}); - } else { - *out_tensor->mut_shape() = Shape({batch_size}); - } - return Maybe<void>::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> { - const Shape& hierarchy = ctx->parallel_hierarchy(); - cfg::NdSbp* output_dist = ctx->NdSbp4ArgNameAndIndex("out", 0); - // the input may be produced by tick which should be broadcast parallel dist - std::vector<cfg::NdSbp*> inputs_dist; - for (const auto& arg_pair : ctx->inputs()) { - inputs_dist.emplace_back(ctx->NdSbp4ArgNameAndIndex(arg_pair.first, arg_pair.second)); 
+/*static*/ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CropMirrorNormalizeFromTensorbufferOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kTensorBuffer); + bool has_mirror = ctx->has_input("mirror", 0); + if (has_mirror) { + const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); + CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); + } + + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + DataType output_dtype = ctx->Attr<DataType>("output_dtype"); + CHECK_EQ_OR_RETURN(output_dtype, + DataType::kFloat); // only support float now; for float16 in future + *out_tensor->mut_data_type() = output_dtype; + + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + bool has_mirror = ctx->has_input("mirror", 0); + if (has_mirror) { + const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); + CHECK_OR_RETURN(mirror_tensor.shape().NumAxes() == 1 + && in_tensor.shape().At(0) == mirror_tensor.shape().At(0)); + } + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t N = in_tensor.shape().At(0); + int64_t H = ctx->Attr<int64_t>("crop_h"); + int64_t W = ctx->Attr<int64_t>("crop_w"); + std::string color_space = ctx->Attr<std::string>("color_space"); + int64_t C = ImageUtil::IsColor(color_space) ? 
3 : 1; + CHECK_EQ_OR_RETURN(in_tensor.shape().NumAxes(), 4); // {N, H, W, C} + CHECK_EQ_OR_RETURN(in_tensor.shape().At(3), C); + if (H == 0 || W == 0) { + H = in_tensor.shape().At(1); + W = in_tensor.shape().At(2); + } else { + H = std::min(H, in_tensor.shape().At(1)); + W = std::min(W, in_tensor.shape().At(2)); + } + std::string output_layout = ctx->Attr<std::string>("output_layout"); + if (output_layout == "NCHW") { + *out_tensor->mut_shape() = Shape({N, C, H, W}); + } else if (output_layout == "NHWC") { + *out_tensor->mut_shape() = Shape({N, H, W, C}); + } else { + return Error::CheckFailedError() << "output_layout: " << output_layout << " is not supported"; + } + + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> CropMirrorNormalizeFromUint8Op::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CropMirrorNormalizeFromUint8Op::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in_tensor.data_type(), DataType::kUInt8); + bool has_mirror = ctx->has_input("mirror", 0); + if (has_mirror) { + const user_op::TensorDesc& mirror_tensor = ctx->InputTensorDesc("mirror", 0); + CHECK_EQ_OR_RETURN(mirror_tensor.data_type(), DataType::kInt8); + } + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + DataType output_dtype = ctx->Attr<DataType>("output_dtype"); + CHECK_EQ_OR_RETURN(output_dtype, + DataType::kFloat); // only support float now; for float16 in future + *out_tensor->mut_data_type() = output_dtype; + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CoinFlipOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t batch_size = ctx->Attr<int64_t>("batch_size"); + *out_tensor->mut_shape() = Shape({batch_size}); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CoinFlipOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + int64_t batch_size = ctx->Attr<int64_t>("batch_size"); + const ParallelContext& parallel_ctx = ctx->parallel_ctx(); + const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); + if (parallel_ctx.parallel_num() > 1 && out_sbp.has_split_parallel()) { + BalancedSplitter bs(batch_size, parallel_ctx.parallel_num()); + *out_tensor->mut_shape() = Shape({bs.At(parallel_ctx.parallel_id()).size()}); + } else { + *out_tensor->mut_shape() = Shape({batch_size}); + } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CoinFlipOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CoinFlipOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const Shape& hierarchy = ctx->parallel_hierarchy(); + cfg::NdSbp* output_dist = ctx->NdSbp4ArgNameAndIndex("out", 0); + // the input may be produced by tick which should be broadcast parallel dist + std::vector<cfg::NdSbp*> inputs_dist; + for (const auto& arg_pair : ctx->inputs()) { + inputs_dist.emplace_back(ctx->NdSbp4ArgNameAndIndex(arg_pair.first, arg_pair.second)); + } + const auto& dist_conf = ctx->user_op_conf().attr<std::vector<std::string>>("nd_sbp"); + if (dist_conf.size() == 0) { + FOR_RANGE(int, i, 0, hierarchy.NumAxes()) { + output_dist->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + for (auto* 
input_dist : inputs_dist) { + input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); } - const auto& dist_conf = ctx->user_op_conf().attr<std::vector<std::string>>("nd_sbp"); - if (dist_conf.size() == 0) { - FOR_RANGE(int, i, 0, hierarchy.NumAxes()) { - output_dist->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - for (auto* input_dist : inputs_dist) { - input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); - } - } - } else { - CHECK_EQ_OR_RETURN(dist_conf.size(), hierarchy.NumAxes()); - for (const std::string& sbp_str : dist_conf) { - cfg::SbpParallel sbp_parallel; - CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); - CHECK_OR_RETURN( - (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0) - || sbp_parallel.has_broadcast_parallel()); - *output_dist->add_sbp_parallel() = sbp_parallel; - for (auto* input_dist : inputs_dist) { - input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); - } - } + } + } else { + CHECK_EQ_OR_RETURN(dist_conf.size(), hierarchy.NumAxes()); + for (const std::string& sbp_str : dist_conf) { + cfg::SbpParallel sbp_parallel; + CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_str, &sbp_parallel)); + CHECK_OR_RETURN( + (sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == 0) + || sbp_parallel.has_broadcast_parallel()); + *output_dist->add_sbp_parallel() = sbp_parallel; + for (auto* input_dist : inputs_dist) { + input_dist->add_sbp_parallel()->mutable_broadcast_parallel(); } - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - ctx->NewBuilder().Split(user_op::OpArg("out", 0), 0).Build(); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_data_type() = DataType::kInt8; - return Maybe<void>::Ok(); - }); -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_random_crop") - .Input("in") - .Output("out") - .Attr<int32_t>("num_attempts", 10) - .Attr<int64_t>("seed", -1) - .Attr<bool>("has_seed", false) - .Attr<std::vector<float>>("random_area", {0.08, 1.0}) - .Attr<std::vector<float>>("random_aspect_ratio", {0.75, 1.333333}) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = in_tensor.shape(); - *out_tensor->mut_is_dynamic() = in_tensor.is_dynamic(); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe<void> { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - *ctx->OutputDType("out", 0) = in_tensor.data_type(); - return Maybe<void>::Ok(); - }); + } + } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> CoinFlipOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_data_type() = DataType::kInt8; + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageRandomCropOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + 
user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = in_tensor.shape(); + *out_tensor->mut_is_dynamic() = in_tensor.is_dynamic(); + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ImageRandomCropOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ImageRandomCropOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} + +/* static */ Maybe<void> ImageRandomCropOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageRandomCropOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); + *ctx->OutputDType("out", 0) = in_tensor.data_type(); + return Maybe<void>::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_resize_ops.cpp b/oneflow/user/ops/image_resize_ops.cpp index eb030c17531..fe6f351ecaf 100644 --- a/oneflow/user/ops/image_resize_ops.cpp +++ b/oneflow/user/ops/image_resize_ops.cpp @@ -15,132 +15,130 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/image/image_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_resize_to_fixed") - .Input("in") - .Output("out") - .Output("scale") - .Attr<int64_t>("target_width", 0) - .Attr<int64_t>("target_height", 0) - .Attr<int64_t>("channels", 3) - .Attr<DataType>("data_type", DataType::kUInt8) - .Attr<std::string>("interpolation_type", "bilinear") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe<void> { - bool check_failed = false; - std::ostringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - int64_t target_width = conf.attr<int64_t>("target_width"); - int64_t target_height = conf.attr<int64_t>("target_height"); - if (target_width <= 0 || target_height <= 0) { - err << ", target_width: " << target_width << ", target_height: " << target_height; - check_failed = true; - } - int64_t channels = conf.attr<int64_t>("channels"); - if (channels != 1 && channels != 3) { - err << ", channels: " << channels << " (channels can only be 1 or 3)"; - check_failed = true; - } - DataType data_type = conf.attr<DataType>("data_type"); - if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { - err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; - check_failed = true; - } - const std::string& interp_type = conf.attr<std::string>("interpolation_type"); - if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe<void>::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().elem_cnt() > 0); - int64_t batch_size = in_tensor.shape().elem_cnt(); - int64_t target_width = ctx->Attr<int64_t>("target_width"); - int64_t target_height = ctx->Attr<int64_t>("target_height"); - int64_t channels = ctx->Attr<int64_t>("channels"); - - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() 
= Shape({batch_size, target_height, target_width, channels}); - out_tensor->set_is_dynamic(in_tensor.is_dynamic()); - - user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); - *scale_tensor->mut_shape() = Shape({batch_size, 2}); - scale_tensor->set_is_dynamic(in_tensor.is_dynamic()); - - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_data_type() = ctx->Attr<DataType>("data_type"); - user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); - *scale_tensor->mut_data_type() = DataType::kFloat; - return Maybe<void>::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_resize_keep_aspect_ratio") - .Input("in") - .Output("out") - .Output("size") - .Output("scale") - .Attr<int32_t>("target_size") - .Attr<int32_t>("min_size", 0) - .Attr<int32_t>("max_size", 0) - .Attr<bool>("resize_longer", false) - .Attr<std::string>("interpolation_type", "bilinear") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe<void> { - bool check_failed = false; - std::ostringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const int32_t target_size = conf.attr<int32_t>("target_size"); - const int32_t max_size = conf.attr<int32_t>("max_size"); - if (target_size <= 0) { - err << ", target_size: " << target_size << " (target_size must be greater than 0)"; - check_failed = true; - } - if (max_size < target_size && max_size > 0) { - err << ", max_size: " << max_size - << " (max_size must be greater than target_size or equal to 0)"; - check_failed = true; - } - const std::string& interp_type = conf.attr<std::string>("interpolation_type"); - if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe<void>::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_shape() = in_desc.shape(); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_data_type() = DataType::kTensorBuffer; - return Maybe<void>::Ok(); - }); +/* 
static */ Maybe<void> ImageResizeToFixedOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().elem_cnt() > 0); + int64_t batch_size = in_tensor.shape().elem_cnt(); + int64_t target_width = ctx->Attr<int64_t>("target_width"); + int64_t target_height = ctx->Attr<int64_t>("target_height"); + int64_t channels = ctx->Attr<int64_t>("channels"); + + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({batch_size, target_height, target_width, channels}); + out_tensor->set_is_dynamic(in_tensor.is_dynamic()); + + user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); + *scale_tensor->mut_shape() = Shape({batch_size, 2}); + scale_tensor->set_is_dynamic(in_tensor.is_dynamic()); + + return Maybe<void>::Ok(); +} + +/*static*/ Maybe<void> ImageResizeToFixedOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ImageResizeToFixedOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageResizeToFixedOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::ostringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + int64_t target_width = conf.attr<int64_t>("target_width"); + int64_t target_height = conf.attr<int64_t>("target_height"); + if (target_width <= 0 || target_height <= 0) { + err << ", target_width: " << target_width << ", target_height: " << target_height; + check_failed = true; + } + int64_t channels = conf.attr<int64_t>("channels"); + if (channels != 1 && channels != 3) { + err << ", channels: " << channels << " (channels can only be 1 or 3)"; + check_failed = true; + } + DataType data_type = conf.attr<DataType>("data_type"); + if (data_type != DataType::kUInt8 && data_type != DataType::kFloat) { + err << ", data_type: " << data_type << " (only support kUInt8 and kFloat for now)"; + check_failed = true; + } + const std::string& interp_type = conf.attr<std::string>("interpolation_type"); + if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageResizeToFixedOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_data_type() = ctx->Attr<DataType>("data_type"); + user_op::TensorDesc* scale_tensor = ctx->OutputTensorDesc("scale", 0); + *scale_tensor->mut_data_type() = DataType::kFloat; + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) > 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_shape() = in_desc.shape(); + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_shape() = in_desc.shape(); + return Maybe<void>::Ok(); +} + 
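A worked example of the two resize ops' inference above, assuming a batch of 8 tensor-buffer images with target_width = target_height = 224, channels = 3, data_type = kUInt8 (values chosen purely for illustration):

// image_resize_to_fixed: the output batch has a static pixel shape
//   in    : kTensorBuffer, shape {8}
//   out   : kUInt8,        shape {8, 224, 224, 3}   // {N, H, W, C}
//   scale : kFloat,        shape {8, 2}             // per-image resize factors
//
// image_resize_keep_aspect_ratio: output sizes vary per image, so out, size
// and scale all remain per-image tensor buffers of shape {8}

This is the key difference between the two ops: resize-to-fixed can commit to a dense uint8/float batch at inference time, while keep-aspect-ratio must defer the pixel shape to the buffers themselves.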
+/*static*/ Maybe<void> ImageResizeKeepAspectRatioOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::CheckAttr( + const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& conf) { + bool check_failed = false; + std::ostringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const int32_t target_size = conf.attr<int32_t>("target_size"); + const int32_t max_size = conf.attr<int32_t>("max_size"); + if (target_size <= 0) { + err << ", target_size: " << target_size << " (target_size must be greater than 0)"; + check_failed = true; + } + if (max_size < target_size && max_size > 0) { + err << ", max_size: " << max_size + << " (max_size must be greater than target_size or equal to 0)"; + check_failed = true; + } + const std::string& interp_type = conf.attr<std::string>("interpolation_type"); + if (!CheckInterpolationValid(interp_type, err)) { check_failed = true; } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe<void>::Ok(); +} + +/* static */ Maybe<void> ImageResizeKeepAspectRatioOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_data_type() = DataType::kTensorBuffer; + return Maybe<void>::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/image_target_resize_op.cpp b/oneflow/user/ops/image_target_resize_op.cpp index dcf6e78ce85..49d7db09479 100644 --- a/oneflow/user/ops/image_target_resize_op.cpp +++ b/oneflow/user/ops/image_target_resize_op.cpp @@ -14,59 +14,60 @@ See the License for the specific language governing permissions and limitations under the License.
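The CheckAttr methods in these resize files share one accumulate-then-fail pattern: each attr is validated in turn, every complaint is appended to a single stream, and one CheckFailedError carrying the whole message is returned at the end, so a misconfigured op reports all bad attrs at once instead of only the first. A hypothetical call site, just to show how the error propagates (names are illustrative, not from this diff):

Maybe<void> ValidateResizeConf(const user_op::UserOpDefWrapper& def,
                               const user_op::UserOpConfWrapper& conf) {
  // JUST(...) returns early with the error Maybe if any attr check failed
  JUST(ImageResizeKeepAspectRatioOp::CheckAttr(def, conf));
  return Maybe<void>::Ok();
}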
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("image_target_resize") - .Input("in") - .Output("out") - .Output("size") - .Output("scale") - .Attr("target_size") - .Attr("max_size") - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& def, - const user_op::UserOpConfWrapper& conf) -> Maybe { - bool check_failed = false; - std::stringstream err; - err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); - const int32_t target_size = conf.attr("target_size"); - const int32_t max_size = conf.attr("max_size"); - if (target_size <= 0) { - err << ", target_size: " << target_size << " (target_size must be greater than 0)"; - check_failed = true; - } - if (max_size < target_size) { - err << ", max_size: " << max_size << " (max_size must be greater than 0)"; - check_failed = true; - } - if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_data_type() = DataType::kTensorBuffer; - user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); - *size_desc->mut_data_type() = DataType::kInt32; - user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); - *scale_desc->mut_data_type() = DataType::kFloat; - return Maybe::Ok(); - }); +/* static */ Maybe ImageTargetResizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.shape().NumAxes() == 1 && in_desc.shape().At(0) >= 1); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_shape() = Shape({in_desc.shape().elem_cnt(), 2}); + return Maybe::Ok(); +} + +/*static*/ Maybe ImageTargetResizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ImageTargetResizeOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ImageTargetResizeOp::CheckAttr(const user_op::UserOpDefWrapper& def, + const user_op::UserOpConfWrapper& conf) { + bool check_failed 
= false; + std::stringstream err; + err << "Illegal attr value for " << conf.op_type_name() << " op, op_name: " << conf.op_name(); + const int32_t target_size = conf.attr("target_size"); + const int32_t max_size = conf.attr("max_size"); + if (target_size <= 0) { + err << ", target_size: " << target_size << " (target_size must be greater than 0)"; + check_failed = true; + } + if (max_size < target_size) { + err << ", max_size: " << max_size << " (max_size must be greater than or equal to target_size)"; + check_failed = true; + } + if (check_failed) { return oneflow::Error::CheckFailedError() << err.str(); } + return Maybe::Ok(); +} + +/* static */ Maybe ImageTargetResizeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(in_desc.data_type() == DataType::kTensorBuffer); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_data_type() = DataType::kTensorBuffer; + user_op::TensorDesc* size_desc = ctx->OutputTensorDesc("size", 0); + *size_desc->mut_data_type() = DataType::kInt32; + user_op::TensorDesc* scale_desc = ctx->OutputTensorDesc("scale", 0); + *scale_desc->mut_data_type() = DataType::kFloat; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/in_top_k_op.cpp b/oneflow/user/ops/in_top_k_op.cpp index dc3f1a5858f..6ee9b5592e4 100644 --- a/oneflow/user/ops/in_top_k_op.cpp +++ b/oneflow/user/ops/in_top_k_op.cpp @@ -14,38 +14,40 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("in_top_k") - .Input("targets") - .Input("predictions") - .Attr("k") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); - const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1); - CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2); - const bool is_dynamic = targets.is_dynamic(); - CHECK_EQ_OR_RETURN(is_dynamic, predictions.is_dynamic()); - out->set_is_dynamic(is_dynamic); - *out->mut_shape() = targets.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); - CHECK_OR_RETURN(IsIndexDataType(targets.data_type())); - const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); - CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = kInt8; - return Maybe::Ok(); - }); +/* static */ Maybe InTopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); + const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_EQ_OR_RETURN(targets.shape().NumAxes(), 1); + CHECK_EQ_OR_RETURN(predictions.shape().NumAxes(), 2); + const bool is_dynamic = targets.is_dynamic(); + CHECK_EQ_OR_RETURN(is_dynamic, predictions.is_dynamic()); +
out->set_is_dynamic(is_dynamic); + *out->mut_shape() = targets.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe InTopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/* static */ Maybe InTopKOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); } + +/* static */ Maybe InTopKOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& targets = ctx->InputTensorDesc("targets", 0); + CHECK_OR_RETURN(IsIndexDataType(targets.data_type())); + const user_op::TensorDesc& predictions = ctx->InputTensorDesc("predictions", 0); + CHECK_EQ_OR_RETURN(predictions.data_type(), DataType::kFloat); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = kInt8; + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp b/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp index e5e9c035b5a..5b61c8ff2ba 100644 --- a/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp +++ b/oneflow/user/ops/indexed_slices_reduce_sum_op.cpp @@ -14,42 +14,47 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("indexed_slices_reduce_sum") - .Input("x_indices") - .Input("x_values") - .Output("y_indices") - .Output("y_values") - .Output("num_unique") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); - const user_op::TensorDesc& x_values = ctx->InputTensorDesc("x_values", 0); - CHECK_LT_OR_RETURN(x_indices.shape().NumAxes(), x_values.shape().NumAxes()); - FOR_RANGE(int64_t, i, 0, x_indices.shape().NumAxes()) { - CHECK_EQ_OR_RETURN(x_indices.shape().At(i), x_values.shape().At(i)); - } - - const int64_t n = x_indices.shape().elem_cnt(); - const int64_t m = x_values.shape().elem_cnt() / n; - user_op::TensorDesc* y_indices = ctx->OutputTensorDesc("y_indices", 0); - user_op::TensorDesc* y_values = ctx->OutputTensorDesc("y_values", 0); - *y_indices = x_indices; - *y_indices->mut_shape() = Shape({n}); - *y_values = x_values; - *y_values->mut_shape() = Shape({n, m}); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); - CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type())); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_data_type() = DataType::kInt64; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe IndexedSlicesReduceSumOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); + const user_op::TensorDesc& x_values = ctx->InputTensorDesc("x_values", 0); + CHECK_LT_OR_RETURN(x_indices.shape().NumAxes(), x_values.shape().NumAxes()); + FOR_RANGE(int64_t, i, 0, x_indices.shape().NumAxes()) { + CHECK_EQ_OR_RETURN(x_indices.shape().At(i), x_values.shape().At(i)); + } + + const int64_t n = x_indices.shape().elem_cnt(); + const int64_t m = x_values.shape().elem_cnt() / n; + user_op::TensorDesc* y_indices 
= ctx->OutputTensorDesc("y_indices", 0); + user_op::TensorDesc* y_values = ctx->OutputTensorDesc("y_values", 0); + *y_indices = x_indices; + *y_indices->mut_shape() = Shape({n}); + *y_values = x_values; + *y_values->mut_shape() = Shape({n, m}); + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_shape() = Shape({1}); + return Maybe::Ok(); +} + +/*static*/ Maybe IndexedSlicesReduceSumOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesReduceSumOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe IndexedSlicesReduceSumOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_indices = ctx->InputTensorDesc("x_indices", 0); + CHECK_OR_RETURN(IsIndexDataType(x_indices.data_type())); + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_data_type() = DataType::kInt64; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/kl_div_op.cpp b/oneflow/user/ops/kl_div_op.cpp index fce4255f74d..cb58f29764b 100644 --- a/oneflow/user/ops/kl_div_op.cpp +++ b/oneflow/user/ops/kl_div_op.cpp @@ -16,10 +16,11 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { namespace { -Maybe InferTensorDescFn(user_op::InferContext* ctx) { +Maybe KlInferTensorDescFn(user_op::InferContext* ctx) { const auto& input_desc = ctx->InputTensorDesc("input", 0); const auto& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic()); @@ -32,7 +33,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe KlInferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0); const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type()); @@ -69,48 +70,58 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_USER_OP("kl_div_loss") - .Input("input") - .Input("target") - .Output("out") - .Attr("log_target") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); - FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); - } - return Maybe::Ok(); - }); +/* static */ Maybe KlDivLossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return KlInferTensorDescFn(ctx); +} -REGISTER_USER_OP("kl_div_loss_grad") - .Input("input") - .Input("target") - .Input("dy") - .Output("dx") - .Attr("log_target") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const 
auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); - FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("input", 0), i) - .Split(user_op::OpArg("target", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe KlDivLossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe KlDivLossOp::GetSbp(user_op::SbpContext* ctx) { + const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe KlDivLossOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe KlDivLossOp::InferDataType(user_op::InferContext* ctx) { + return KlInferDataType(ctx); +} + +/* static */ Maybe KlDivLossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe KlDivLossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe KlDivLossGradOp::GetSbp(user_op::SbpContext* ctx) { + const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape(); + FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("input", 0), i) + .Split(user_op::OpArg("target", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe KlDivLossGradOp::InferDataType(user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("kl_div_loss") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp index c99910e689c..05affa22404 100644 --- a/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp +++ b/oneflow/user/ops/l1_l2_regularize_gradient_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
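The kl_div hunks above complete the four-hook pattern this refactor repeats for every op: logic moves unchanged from registration-time lambdas into static members of a class generated from the ODS definitions. As orientation, a sketch of the per-op declarations op_generated.h is assumed to emit (the real header is produced by oneflow-tblgen, so member order and the optional hooks are illustrative; Maybe<void> is the full spelling of the Maybe return type used throughout):

// Sketch only: assumed shape of one class emitted into op_generated.h by
// oneflow-tblgen. The .cpp hunks in this diff supply the out-of-line bodies.
// Within the repo this sits behind #include "oneflow/core/framework/framework.h".
namespace oneflow {
struct KlDivLossOp {
  static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);   // logical shapes
  static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);  // usually delegates
  static Maybe<void> GetSbp(user_op::SbpContext* ctx);                     // parallel signatures
  static Maybe<void> InferDataType(user_op::InferContext* ctx);            // dtype propagation
  // Hooks like these appear only on ops whose definition requests them:
  static Maybe<void> ModifyInputArg(const user_op::GetInputArgModifier& fn,
                                    const user_op::UserOpConfWrapper& conf);
  static Maybe<void> CheckAttr(const user_op::UserOpDefWrapper& def,
                               const user_op::UserOpConfWrapper& conf);
};
}  // namespace oneflow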
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -38,20 +39,26 @@ Maybe GetSbpSignatures(user_op::SbpContext* ctx) { } // namespace -REGISTER_NO_GRAD_USER_OP("l1_l2_regularize_gradient") - .Input("model") - .Input("model_diff") - .Output("out") - .Attr("l1", 0) - .Attr("l2", 0) - .SetTensorDescInferFn(InferTensorDesc) - .SetGetSbpFn(GetSbpSignatures) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); - const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); - CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type()); - *ctx->OutputDType("out", 0) = ctx->InputDType("model", 0); - return Maybe::Ok(); - }); +/* static */ Maybe L1L2RegularizeGradientOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferTensorDesc(ctx); +} + +/*static*/ Maybe L1L2RegularizeGradientOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe L1L2RegularizeGradientOp::GetSbp(user_op::SbpContext* ctx) { + return GetSbpSignatures(ctx); +} + +/* static */ Maybe L1L2RegularizeGradientOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& model = ctx->InputTensorDesc("model", 0); + const user_op::TensorDesc& model_diff = ctx->InputTensorDesc("model_diff", 0); + CHECK_EQ_OR_RETURN(model_diff.data_type(), model.data_type()); + *ctx->OutputDType("out", 0) = ctx->InputDType("model", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/l2_normalize_op.cpp b/oneflow/user/ops/l2_normalize_op.cpp index 1e8abab4651..d1723c41c97 100644 --- a/oneflow/user/ops/l2_normalize_op.cpp +++ b/oneflow/user/ops/l2_normalize_op.cpp @@ -14,98 +14,98 @@ See the License for the specific language governing permissions and limitations under the License. 
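The l1_l2_regularize_gradient hunk above only relocates shape/dtype inference; the kernel arithmetic is not part of this diff. Assuming the usual definition of L1/L2 regularization, the op would compute out = model_diff + l1 * sign(model) + l2 * model, as in this self-contained sketch:

#include <cstdio>
#include <vector>

// Assumed semantics of l1_l2_regularize_gradient (the kernel lives outside
// this diff): add L1 and L2 regularization terms to the raw gradient.
std::vector<float> L1L2RegularizeGradient(const std::vector<float>& model,
                                          const std::vector<float>& model_diff,
                                          float l1, float l2) {
  std::vector<float> out(model.size());
  for (size_t i = 0; i < model.size(); ++i) {
    const float sign = (model[i] > 0.f) - (model[i] < 0.f);  // -1, 0, or +1
    out[i] = model_diff[i] + l1 * sign + l2 * model[i];
  }
  return out;
}

int main() {
  for (float v : L1L2RegularizeGradient({0.5f, -2.0f}, {0.1f, 0.2f}, 0.01f, 0.001f)) {
    std::printf("%f\n", v);  // 0.110500, 0.188000
  }
}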
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("l2_normalize") - .Input("x") - .Output("y") - .Output("square_x_sum") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - Shape* square_x_sum_shape = ctx->OutputShape("square_x_sum", 0); - const int32_t axis = ctx->Attr("axis"); - const float epsilon = ctx->Attr("epsilon"); - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); - CHECK_GT_OR_RETURN(epsilon, 0); - *y_shape = x_shape; - *square_x_sum_shape = x_shape; - square_x_sum_shape->Set(axis, 1); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const int32_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i != axis) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("square_x_sum", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("square_x_sum", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe L2NormalizeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + Shape* square_x_sum_shape = ctx->OutputShape("square_x_sum", 0); + const int32_t axis = ctx->Attr("axis"); + const float epsilon = ctx->Attr("epsilon"); + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LT_OR_RETURN(axis, x_shape.NumAxes()); + CHECK_GT_OR_RETURN(epsilon, 0); + *y_shape = x_shape; + *square_x_sum_shape = x_shape; + square_x_sum_shape->Set(axis, 1); + return Maybe::Ok(); +} -REGISTER_USER_OP("l2_normalize_grad") - .Input("dy") - .Input("y") - .Input("square_x_sum") - .Output("dx") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - const int32_t axis = ctx->Attr("axis"); - const float epsilon = ctx->Attr("epsilon"); - CHECK_EQ_OR_RETURN(dy_shape, y_shape); - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LT_OR_RETURN(axis, dy_shape.NumAxes()); - CHECK_GT_OR_RETURN(epsilon, 0); - FOR_RANGE(int32_t, i, 0, dy_shape.NumAxes()) { - if (i == axis) { - CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), 1); - } else { - CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), dy_shape.At(i)); - } - } - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - const int32_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - if (i != axis) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("square_x_sum", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), 
ctx->InputDType("dy", 0)); - CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("square_x_sum", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe L2NormalizeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe L2NormalizeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const int32_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i != axis) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("square_x_sum", 0), i) + .Build(); + } + } + return Maybe::Ok(); +} + +/* static */ Maybe L2NormalizeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("square_x_sum", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe L2NormalizeGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& square_x_sum_shape = ctx->InputShape("square_x_sum", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + const int32_t axis = ctx->Attr("axis"); + const float epsilon = ctx->Attr("epsilon"); + CHECK_EQ_OR_RETURN(dy_shape, y_shape); + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LT_OR_RETURN(axis, dy_shape.NumAxes()); + CHECK_GT_OR_RETURN(epsilon, 0); + FOR_RANGE(int32_t, i, 0, dy_shape.NumAxes()) { + if (i == axis) { + CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), 1); + } else { + CHECK_EQ_OR_RETURN(square_x_sum_shape.At(i), dy_shape.At(i)); + } + } + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe L2NormalizeGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe L2NormalizeGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + const int32_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + if (i != axis) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("square_x_sum", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + } + return Maybe::Ok(); +} + +/* static */ Maybe L2NormalizeGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0)); + CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("square_x_sum", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("l2_normalize") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/layer_norm_op.cpp b/oneflow/user/ops/layer_norm_op.cpp index 351c1d1d42a..d51522d4405 100644 --- a/oneflow/user/ops/layer_norm_op.cpp +++ b/oneflow/user/ops/layer_norm_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -40,242 +41,228 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) { } // namespace -REGISTER_USER_OP("layer_norm") - .Input("x") - .OptionalInput("beta") - .OptionalInput("gamma") - .Output("y") - .Output("mean") - .Output("inv_variance") - .OptionalOutput("normalized") - .Attr("center") - .Attr("scale") - .Attr("begin_norm_axis") - .Attr("begin_params_axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); - user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); - const bool center = ctx->Attr("center"); - const bool scale = ctx->Attr("scale"); - const int64_t begin_params_axis = - ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_params_axis")); - *y->mut_shape() = x.shape(); - *y->mut_is_dynamic() = x.is_dynamic(); - DimVector param_shape_dim_vec; - param_shape_dim_vec.insert(param_shape_dim_vec.end(), - x.shape().dim_vec().cbegin() + begin_params_axis, - x.shape().dim_vec().cend()); - const Shape param_shape(param_shape_dim_vec); - if (center) { - const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); - CHECK_EQ_OR_RETURN(beta.shape(), param_shape); - } - if (scale) { - user_op::TensorDesc* normalized = ctx->OutputTensorDesc("normalized", 0); - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); - *normalized = x; - } - const int64_t begin_norm_axis = - ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_norm_axis")); - *mean->mut_shape() = InferBnParamShape(x.shape(), begin_norm_axis); - *inv_variance = *mean; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); - int64_t begin_norm_axis = - ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_norm_axis")); - int64_t begin_params_axis = - ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_params_axis")); - for (int i = 0; i < std::min(begin_norm_axis, begin_params_axis); ++i) { - ctx->NewBuilder() - .Split(ctx->inputs(), i) - .Split(ctx->outputs(), i) - .Broadcast(user_op::OpArg("gamma", 0)) - .Broadcast(user_op::OpArg("beta", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const bool center = ctx->Attr("center"); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_data_type() = x.data_type(); - if (center) { - const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); - CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type()); - } - const bool scale = ctx->Attr("scale"); - if (scale) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type()); - } - user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); - user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); - *mean->mut_data_type() = InferBnParamDataType(x.data_type()); - *inv_variance->mut_data_type() = mean->data_type(); - return Maybe::Ok(); - }); +/* static */ Maybe LayerNormOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = 
ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); + user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); + const bool center = ctx->Attr("center"); + const bool scale = ctx->Attr("scale"); + const int64_t begin_params_axis = + ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_params_axis")); + *y->mut_shape() = x.shape(); + *y->mut_is_dynamic() = x.is_dynamic(); + DimVector param_shape_dim_vec; + param_shape_dim_vec.insert(param_shape_dim_vec.end(), + x.shape().dim_vec().cbegin() + begin_params_axis, + x.shape().dim_vec().cend()); + const Shape param_shape(param_shape_dim_vec); + if (center) { + const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); + CHECK_EQ_OR_RETURN(beta.shape(), param_shape); + } + if (scale) { + user_op::TensorDesc* normalized = ctx->OutputTensorDesc("normalized", 0); + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); + *normalized = x; + } + const int64_t begin_norm_axis = + ShiftNegativeAxisIfNeed(x.shape(), ctx->Attr("begin_norm_axis")); + *mean->mut_shape() = InferBnParamShape(x.shape(), begin_norm_axis); + *inv_variance = *mean; + return Maybe::Ok(); +} -REGISTER_USER_OP("layer_norm_grad") - .Input("dy") - .Input("x") - .Input("mean") - .Input("inv_variance") - .OptionalInput("_add_to_output") - .Output("dx") - .Attr("begin_norm_axis") - .Attr("epsilon") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); - const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); - user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); - CHECK_EQ_OR_RETURN(dy.shape(), x.shape()); - const int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); - CHECK_GT_OR_RETURN(begin_norm_axis, 0); - const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis); - CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape); - CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape); - *dx->mut_shape() = dy.shape(); - *dx->mut_is_dynamic() = dy.is_dynamic(); - if (ctx->has_input("_add_to_output", 0)) { - const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape()); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); - for (int i = 0; i < begin_norm_axis; ++i) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()); - const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); - const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); - const DataType& bn_param_data_type = InferBnParamDataType(x.data_type()); - CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type); - CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type); - user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); - *dx->mut_data_type() = dy.data_type(); - if 
(ctx->has_input("_add_to_output", 0)) { - const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type()); - } - return Maybe::Ok(); - }); +/*static*/ Maybe LayerNormOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("layer_norm_param_grad") - .Input("dy") - .OptionalInput("normalized") - .OptionalInput("gamma") - .OptionalOutput("normalized_diff") - .OptionalOutput("beta_diff") - .OptionalOutput("gamma_diff") - .OptionalOutput("reduce_buf") - .Attr("begin_params_axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - // TODO: tsai: replace lambda with user op if - auto has_tensor = [ctx](const std::string& bn) -> bool { - bool ret = false; - for (auto t : ctx->inputs()) { - if (bn == t.first) { return true; } - } - for (auto t : ctx->outputs()) { - if (bn == t.first) { return true; } - } - return ret; - }; - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); - const bool has_beta_diff = has_tensor("beta_diff"); - const bool has_gamma_diff = has_tensor("gamma_diff"); - const bool has_gamma = has_tensor("gamma"); - const bool has_normalized_diff = has_tensor("normalized_diff"); - if (has_beta_diff || has_gamma_diff) { - user_op::TensorDesc* reduce_buf = ctx->OutputTensorDesc("reduce_buf", 0); - *reduce_buf = dy; - } - CHECK_GE_OR_RETURN(begin_params_axis, 1); - CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()); - DimVector param_shape_dim_vec; - param_shape_dim_vec.insert(param_shape_dim_vec.end(), - dy.shape().dim_vec().cbegin() + begin_params_axis, - dy.shape().dim_vec().cend()); - const Shape param_shape(param_shape_dim_vec); - if (has_beta_diff) { - user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); - *beta_diff->mut_shape() = param_shape; - } - if (has_gamma_diff) { - user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); - *gamma_diff->mut_shape() = param_shape; - } - if (has_normalized_diff) { - user_op::TensorDesc* normalized_diff = ctx->OutputTensorDesc("normalized_diff", 0); - *normalized_diff = dy; - } - if (has_gamma) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); - } - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t begin_params_axis = ctx->Attr("begin_params_axis"); - for (int i = 0; i < begin_params_axis; ++i) { - ctx->NewBuilder() - .Split(ctx->inputs(), i) - .Split(ctx->outputs(), i) - .Broadcast(user_op::OpArg("gamma", 0)) - .PartialSum(user_op::OpArg("gamma_diff", 0)) - .PartialSum(user_op::OpArg("beta_diff", 0)) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - auto has_tensor = [ctx](const std::string& bn) -> bool { - bool ret = false; - for (auto& t : ctx->inputs()) { - if (bn == t.first) { return true; } - } - for (auto& t : ctx->outputs()) { - if (bn == t.first) { return true; } - } - return ret; - }; - const bool has_beta_diff = has_tensor("beta_diff"); - const bool has_gamma_diff = has_tensor("gamma_diff"); - const bool has_gamma = has_tensor("gamma"); - const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); - if (has_beta_diff) { - user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); - *beta_diff->mut_data_type() = dy.data_type(); - } - if (has_gamma_diff) { - user_op::TensorDesc* 
gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); - const user_op::TensorDesc& normalized = ctx->InputTensorDesc("normalized", 0); - CHECK_EQ_OR_RETURN(normalized.data_type(), normalized.data_type()); - *gamma_diff->mut_data_type() = dy.data_type(); - } - if (has_gamma) { - const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); - CHECK_EQ_OR_RETURN(gamma.data_type(), dy.data_type()); - } - return Maybe::Ok(); - }); +/* static */ Maybe LayerNormOp::GetSbp(user_op::SbpContext* ctx) { + const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape(); + int64_t begin_norm_axis = ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_norm_axis")); + int64_t begin_params_axis = + ShiftNegativeAxisIfNeed(x_shape, ctx->Attr("begin_params_axis")); + for (int i = 0; i < std::min(begin_norm_axis, begin_params_axis); ++i) { + ctx->NewBuilder() + .Split(ctx->inputs(), i) + .Split(ctx->outputs(), i) + .Broadcast(user_op::OpArg("gamma", 0)) + .Broadcast(user_op::OpArg("beta", 0)) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LayerNormOp::InferDataType(user_op::InferContext* ctx) { + const bool center = ctx->Attr("center"); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_data_type() = x.data_type(); + if (center) { + const user_op::TensorDesc& beta = ctx->InputTensorDesc("beta", 0); + CHECK_EQ_OR_RETURN(beta.data_type(), x.data_type()); + } + const bool scale = ctx->Attr("scale"); + if (scale) { + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.data_type(), x.data_type()); + } + user_op::TensorDesc* mean = ctx->OutputTensorDesc("mean", 0); + user_op::TensorDesc* inv_variance = ctx->OutputTensorDesc("inv_variance", 0); + *mean->mut_data_type() = InferBnParamDataType(x.data_type()); + *inv_variance->mut_data_type() = mean->data_type(); + return Maybe::Ok(); +} + +/* static */ Maybe LayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); + const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); + user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); + CHECK_EQ_OR_RETURN(dy.shape(), x.shape()); + const int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); + CHECK_GT_OR_RETURN(begin_norm_axis, 0); + const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis); + CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape); + CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape); + *dx->mut_shape() = dy.shape(); + *dx->mut_is_dynamic() = dy.is_dynamic(); + if (ctx->has_input("_add_to_output", 0)) { + const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape()); + } + return Maybe::Ok(); +} + +/*static*/ Maybe LayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LayerNormGradOp::GetSbp(user_op::SbpContext* ctx) { + int64_t begin_norm_axis = ctx->Attr("begin_norm_axis"); + for (int i = 0; i < begin_norm_axis; ++i) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LayerNormGradOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& dy = 
ctx->InputTensorDesc("dy", 0); + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type()); + const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0); + const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0); + const DataType& bn_param_data_type = InferBnParamDataType(x.data_type()); + CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type); + CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type); + user_op::TensorDesc* dx = ctx->OutputTensorDesc("dx", 0); + *dx->mut_data_type() = dy.data_type(); + if (ctx->has_input("_add_to_output", 0)) { + const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type()); + } + return Maybe::Ok(); +} + +/* static */ Maybe LayerNormParamGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + // TODO: tsai: replace lambda with user op if + auto has_tensor = [ctx](const std::string& bn) -> bool { + bool ret = false; + for (const auto& t : ctx->inputs()) { + if (bn == t.first) { return true; } + } + for (const auto& t : ctx->outputs()) { + if (bn == t.first) { return true; } + } + return ret; + }; + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + const int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + const bool has_beta_diff = has_tensor("beta_diff"); + const bool has_gamma_diff = has_tensor("gamma_diff"); + const bool has_gamma = has_tensor("gamma"); + const bool has_normalized_diff = has_tensor("normalized_diff"); + if (has_beta_diff || has_gamma_diff) { + user_op::TensorDesc* reduce_buf = ctx->OutputTensorDesc("reduce_buf", 0); + *reduce_buf = dy; + } + CHECK_GE_OR_RETURN(begin_params_axis, 1); + CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes()); + DimVector param_shape_dim_vec; + param_shape_dim_vec.insert(param_shape_dim_vec.end(), + dy.shape().dim_vec().cbegin() + begin_params_axis, + dy.shape().dim_vec().cend()); + const Shape param_shape(param_shape_dim_vec); + if (has_beta_diff) { + user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); + *beta_diff->mut_shape() = param_shape; + } + if (has_gamma_diff) { + user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); + *gamma_diff->mut_shape() = param_shape; + } + if (has_normalized_diff) { + user_op::TensorDesc* normalized_diff = ctx->OutputTensorDesc("normalized_diff", 0); + *normalized_diff = dy; + } + if (has_gamma) { + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.shape(), param_shape); + } + return Maybe::Ok(); +} + +/*static*/ Maybe LayerNormParamGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LayerNormParamGradOp::GetSbp(user_op::SbpContext* ctx) { + int64_t begin_params_axis = ctx->Attr("begin_params_axis"); + for (int i = 0; i < begin_params_axis; ++i) { + ctx->NewBuilder() + .Split(ctx->inputs(), i) + .Split(ctx->outputs(), i) + .Broadcast(user_op::OpArg("gamma", 0)) + .PartialSum(user_op::OpArg("gamma_diff", 0)) + .PartialSum(user_op::OpArg("beta_diff", 0)) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LayerNormParamGradOp::InferDataType(user_op::InferContext* ctx) { + auto has_tensor = [ctx](const std::string& bn) -> bool { + bool ret = false; + for (auto& t : ctx->inputs()) { + if (bn == t.first) { return true; } + } + for (auto& t : ctx->outputs()) { + if (bn == t.first) { return true; } + } + 
return ret; + }; + const bool has_beta_diff = has_tensor("beta_diff"); + const bool has_gamma_diff = has_tensor("gamma_diff"); + const bool has_gamma = has_tensor("gamma"); + const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0); + if (has_beta_diff) { + user_op::TensorDesc* beta_diff = ctx->OutputTensorDesc("beta_diff", 0); + *beta_diff->mut_data_type() = dy.data_type(); + } + if (has_gamma_diff) { + user_op::TensorDesc* gamma_diff = ctx->OutputTensorDesc("gamma_diff", 0); + const user_op::TensorDesc& normalized = ctx->InputTensorDesc("normalized", 0); + CHECK_EQ_OR_RETURN(normalized.data_type(), dy.data_type()); + *gamma_diff->mut_data_type() = dy.data_type(); + } + if (has_gamma) { + const user_op::TensorDesc& gamma = ctx->InputTensorDesc("gamma", 0); + CHECK_EQ_OR_RETURN(gamma.data_type(), dy.data_type()); + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("layer_norm") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/leaky_relu_op.cpp b/oneflow/user/ops/leaky_relu_op.cpp index f48b34aadd5..09d8b318c54 100644 --- a/oneflow/user/ops/leaky_relu_op.cpp +++ b/oneflow/user/ops/leaky_relu_op.cpp @@ -14,65 +14,69 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("leaky_relu") - .Input("x") - .Output("y") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - *y_shape = x_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/* static */ Maybe LeakyReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + *y_shape = x_shape; + return Maybe::Ok(); +} -REGISTER_USER_OP("leaky_relu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .Attr("alpha") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - });
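To make the two axis attributes in the layer_norm hunks above concrete: begin_norm_axis drives the mean/inv_variance shape (InferBnParamShape is assumed to keep only the axes before it, one statistic per normalized group), while begin_params_axis drives the gamma/beta shape (the axes from it onward), matching the param_shape_dim_vec computation. A toy version under those assumptions:

#include <cstdio>
#include <vector>

using DimVec = std::vector<long>;

// Assumed behavior of the helpers referenced above (defined outside the hunk):
// statistics keep the leading axes, parameters keep the trailing axes.
DimVec BnParamShape(const DimVec& x, long begin_norm_axis) {
  return DimVec(x.begin(), x.begin() + begin_norm_axis);
}
DimVec ParamShape(const DimVec& x, long begin_params_axis) {
  return DimVec(x.begin() + begin_params_axis, x.end());
}

int main() {
  const DimVec x = {8, 16, 768};             // e.g. [batch, seq, hidden]
  const DimVec mean = BnParamShape(x, 2);    // {8, 16}: one statistic per normalized group
  const DimVec gamma = ParamShape(x, 2);     // {768}: one scale per parameter position
  std::printf("mean rank=%zu gamma rank=%zu\n", mean.size(), gamma.size());
}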
+/*static*/ Maybe LeakyReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LeakyReluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LeakyReluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} + +/* static */ Maybe LeakyReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe LeakyReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LeakyReluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe LeakyReluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("leaky_relu") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/log_softmax_op.cpp b/oneflow/user/ops/log_softmax_op.cpp index 6eff6fb15a3..d8cffbf7460 100644 --- a/oneflow/user/ops/log_softmax_op.cpp +++ b/oneflow/user/ops/log_softmax_op.cpp @@ -14,61 +14,65 @@ See the License for the specific language governing permissions and limitations under the License. 
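For reference, the arithmetic behind the leaky_relu pair above follows the standard definition (the kernels live elsewhere): y = x for x > 0 and alpha * x otherwise, with dx = dy or alpha * dy accordingly. Because dx is linear in dy, the extra Broadcast(x)/PartialSum(dy, dx) signature registered above is mathematically sound:

#include <cstdio>

// Standard leaky-relu math assumed by the ops above; alpha is the float attr.
float LeakyRelu(float x, float alpha) { return x > 0.f ? x : alpha * x; }
// dx depends linearly on dy, so summing partial dy across devices before or
// after the elementwise multiply gives the same dx.
float LeakyReluGrad(float x, float dy, float alpha) { return x > 0.f ? dy : alpha * dy; }

int main() {
  std::printf("%f %f\n", LeakyRelu(-2.f, 0.1f), LeakyReluGrad(-2.f, 1.f, 0.1f));  // -0.2 0.1
}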
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/* static */ Maybe LogSoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("prob", 0) = ctx->InputShape("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("log_softmax") - .Input("in") - .Output("prob") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("prob", 0) = ctx->InputShape("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), axis) - .Split(user_op::OpArg("prob", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("prob", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe LogSoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("log_softmax_grad") - .Input("prob") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("prob", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("prob", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("prob", 0); - FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("prob", 0), axis) - .Split(user_op::OpArg("dy", 0), axis) - .Split(user_op::OpArg("dx", 0), axis) - .Build(); - } - return Maybe::Ok(); - }); +/* static */ Maybe LogSoftmaxOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), axis) + .Split(user_op::OpArg("prob", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LogSoftmaxOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("prob", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe LogSoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("prob", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} + +/*static*/ Maybe LogSoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LogSoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("prob", 0); + FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("prob", 0), axis) + .Split(user_op::OpArg("dy", 0), axis) + .Split(user_op::OpArg("dx", 0), 
axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe LogSoftmaxGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("prob", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("prob", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("log_softmax") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -87,5 +91,4 @@ REGISTER_USER_OP_GRAD("log_softmax") return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/logical_not_op.cpp b/oneflow/user/ops/logical_not_op.cpp index 47d5b1ae20e..c4f549fa7ce 100644 --- a/oneflow/user/ops/logical_not_op.cpp +++ b/oneflow/user/ops/logical_not_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -26,11 +27,20 @@ Maybe InferDataTypeLogicalNot(user_op::InferContext* ctx) { } // namespace -REGISTER_NO_GRAD_USER_OP("logical_not") - .Input("x") - .Output("y") - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetDataTypeInferFn(InferDataTypeLogicalNot); +/* static */ Maybe LogicalNotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::Unchanged(ctx); +} + +/*static*/ Maybe LogicalNotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LogicalNotOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} + +/* static */ Maybe LogicalNotOp::InferDataType(user_op::InferContext* ctx) { + return InferDataTypeLogicalNot(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/masked_fill_op.cpp b/oneflow/user/ops/masked_fill_op.cpp index aa6e08ba954..44afd9b37e9 100644 --- a/oneflow/user/ops/masked_fill_op.cpp +++ b/oneflow/user/ops/masked_fill_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
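The log_softmax pair above splits every axis except the last, consistent with normalizing over the final axis. Assuming the standard formulas (the kernels are not in this diff), prob = in - log(sum(exp(in))) and dx = dy - exp(prob) * sum(dy):

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Assumed math for the op pair above, over one normalized row (the last axis).
std::vector<float> LogSoftmax(const std::vector<float>& in) {
  const float m = *std::max_element(in.begin(), in.end());  // for numerical stability
  float sum = 0.f;
  for (float v : in) sum += std::exp(v - m);
  const float log_z = m + std::log(sum);
  std::vector<float> prob;
  for (float v : in) prob.push_back(v - log_z);
  return prob;
}

// dx_i = dy_i - softmax_i * sum_j(dy_j), where softmax_i = exp(prob_i).
std::vector<float> LogSoftmaxGrad(const std::vector<float>& prob, const std::vector<float>& dy) {
  float dy_sum = 0.f;
  for (float v : dy) dy_sum += v;
  std::vector<float> dx;
  for (size_t i = 0; i < prob.size(); ++i) dx.push_back(dy[i] - std::exp(prob[i]) * dy_sum);
  return dx;
}

int main() {
  const auto prob = LogSoftmax({1.f, 2.f, 3.f});
  const auto dx = LogSoftmaxGrad(prob, {0.f, 0.f, 1.f});
  std::printf("%f %f %f\n", dx[0], dx[1], dx[2]);
}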
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -73,17 +74,25 @@ Maybe GetMaskedFillInputArgModify(const user_op::GetInputArgModifier& GetI } // namespace -REGISTER_USER_OP("masked_fill") - .Input("x") - .Input("mask") - .Output("out") - .Attr("has_int_operand") - .Attr("has_float_operand") - .Attr("int_operand") - .Attr("float_operand") - .SetTensorDescInferFn(InferMaskedFillTensorDesc) - .SetInputArgModifyFn(GetMaskedFillInputArgModify) - .SetDataTypeInferFn(InferMaskedFillDataType) - .SetGetSbpFn(GetMaskedFillSbpSignatures); +/* static */ Maybe MaskedFillOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferMaskedFillTensorDesc(ctx); +} + +/*static*/ Maybe MaskedFillOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MaskedFillOp::GetSbp(user_op::SbpContext* ctx) { + return GetMaskedFillSbpSignatures(ctx); +} + +/* static */ Maybe MaskedFillOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return GetMaskedFillInputArgModify(GetInputArgModifierFn, conf); +} + +/* static */ Maybe MaskedFillOp::InferDataType(user_op::InferContext* ctx) { + return InferMaskedFillDataType(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/math_binary_broadcast_ops.cpp b/oneflow/user/ops/math_binary_broadcast_ops.cpp index d246629ba6a..918884679b0 100644 --- a/oneflow/user/ops/math_binary_broadcast_ops.cpp +++ b/oneflow/user/ops/math_binary_broadcast_ops.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/ndarray/binary_func.h" #include "oneflow/user/ops/math_binary_broadcast_seq.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -193,26 +194,36 @@ Maybe GetBinaryBroadcastSbpSignature(user_op::SbpContext* ctx) { } // namespace -#define REGISTER_BINARY_BROADCAST_NORMAL_USER_OP(op_name, suffix) \ - REGISTER_USER_OP(op_name) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .SetTensorDescInferFn(InferTensorDescBinaryBroadcastNormal) \ - .SetGetSbpFn(GetBinaryBroadcastSbpSignature) \ - .SetDataTypeInferFn(InferDataTypeBinaryBroadcastNormal); - -#define REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP(op_name, suffix) \ - REGISTER_NO_GRAD_USER_OP(op_name) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .SetTensorDescInferFn(InferTensorDescBinaryBroadcastLogical) \ - .SetGetSbpFn(GetBinaryBroadcastSbpSignature) \ - .SetDataTypeInferFn(InferDataTypeBinaryBroadcastLogical); - -OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_NORMAL_USER_OP, MATH_BINARY_BROADCAST_FUNC_SEQ) +#define REGISTER_BINARY_BROADCAST_NORMAL_USER_OP(op_name, suffix) \ + /* static */ Maybe op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescBinaryBroadcastNormal(ctx); \ + } \ + /*static*/ Maybe op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_name::GetSbp(user_op::SbpContext* ctx) { \ + return GetBinaryBroadcastSbpSignature(ctx); \ + } \ + /* static */ Maybe op_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDataTypeBinaryBroadcastNormal(ctx); \ + } + +#define REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP(op_name, suffix) \ + /* static */ Maybe op_name::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescBinaryBroadcastLogical(ctx); \ + } \ + /*static*/ Maybe 
op_name::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_name::GetSbp(user_op::SbpContext* ctx) { \ + return GetBinaryBroadcastSbpSignature(ctx); \ + } \ + /* static */ Maybe op_name::InferDataType(user_op::InferContext* ctx) { \ + return InferDataTypeBinaryBroadcastLogical(ctx); \ + } + +OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_NORMAL_USER_OP, MATH_BINARY_BROADCAST_FUNC_SEQ_ODS) OF_PP_FOR_EACH_TUPLE(REGISTER_BINARY_BROADCAST_LOGICAL_USER_OP, - MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ) + MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS) } // namespace oneflow diff --git a/oneflow/user/ops/math_binary_broadcast_seq.h b/oneflow/user/ops/math_binary_broadcast_seq.h index c3eeafab202..4dc820c0fc8 100644 --- a/oneflow/user/ops/math_binary_broadcast_seq.h +++ b/oneflow/user/ops/math_binary_broadcast_seq.h @@ -42,6 +42,28 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_or", OR) \ OF_PP_MAKE_TUPLE_SEQ("broadcast_logical_xor", XOR) +#define MATH_BINARY_BROADCAST_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastAddOp, Add) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastSubOp, Sub) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastMulOp, Mul) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastDivOp, Div) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastMinimumOp, Min) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastMaximumOp, Max) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastFloorModOp, FloorMod) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastFmodOp, FMod) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastPowOp, Pow) + +#define MATH_BINARY_BROADCAST_LOGICAL_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastEqualOp, EQ) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastNotEqualOp, NE) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterOp, GT) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastGreaterEqualOp, GE) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLessOp, LT) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLessEqualOp, LE) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalAndOp, AND) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalOrOp, OR) \ + OF_PP_MAKE_TUPLE_SEQ(BroadcastLogicalXorOp, XOR) + } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_BINARY_BROADCAST_SEQ_H_ diff --git a/oneflow/user/ops/math_binary_elementwise_ops.cpp b/oneflow/user/ops/math_binary_elementwise_ops.cpp index f528a76966a..ec6e71f82de 100644 --- a/oneflow/user/ops/math_binary_elementwise_ops.cpp +++ b/oneflow/user/ops/math_binary_elementwise_ops.cpp @@ -15,38 +15,34 @@ limitations under the License. 
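InferTensorDescBinaryBroadcastNormal itself is outside these hunks; assuming it implements standard NumPy-style broadcasting, the z shape comes from right-aligning x and y and taking each pairwise max, with size-1 dims stretching:

#include <cstdio>
#include <stdexcept>
#include <vector>

// Standard broadcast-shape rule assumed for broadcast_add/sub/... above:
// align shapes from the right; each pair of dims must match or one must be 1.
std::vector<long> BroadcastShape(std::vector<long> a, std::vector<long> b) {
  if (a.size() < b.size()) a.insert(a.begin(), b.size() - a.size(), 1);
  if (b.size() < a.size()) b.insert(b.begin(), a.size() - b.size(), 1);
  std::vector<long> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    if (a[i] != b[i] && a[i] != 1 && b[i] != 1) throw std::runtime_error("incompatible dims");
    out[i] = a[i] == 1 ? b[i] : a[i];
  }
  return out;
}

int main() {
  for (long d : BroadcastShape({8, 1, 5}, {4, 5})) std::printf("%ld ", d);  // 8 4 5
}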
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/math_binary_elementwise_seq.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC() \ - SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType) +#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type) \ + /* static */ Maybe op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ + } \ + /*static*/ Maybe op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_type::GetSbp(user_op::SbpContext* ctx) { \ + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ + } \ + /* static */ Maybe op_type::InferDataType(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ + } #define REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD(math_binary_elementwise_type, func_prefix) \ - REGISTER_USER_OP(math_binary_elementwise_type) \ - .Input("x") \ - .Input("y") \ - .Output("z") \ - .MATH_ELEMENTWISE_DEFAULT_SET_FUNC(); \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op); \ \ - REGISTER_USER_OP((std::string("") + math_binary_elementwise_type + "_x_grad")) \ - .Input("x") \ - .Input("y") \ - .Input("dz") \ - .Output("dx") \ - .MATH_ELEMENTWISE_DEFAULT_SET_FUNC(); \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##XGradOp); \ \ - REGISTER_USER_OP((std::string("") + math_binary_elementwise_type + "_y_grad")) \ - .Input("x") \ - .Input("y") \ - .Input("dz") \ - .Output("dy") \ - .MATH_ELEMENTWISE_DEFAULT_SET_FUNC(); \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##YGradOp); \ \ REGISTER_USER_OP_GRAD(math_binary_elementwise_type) \ .SetGenBackwardOpConfFn( \ - [](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { \ + [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { \ if (op.NeedGenGradTensor4OpInput("x", 0)) { \ user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_x_grad"); \ user_op::UserOpConfWrapper binary_grad_op = \ @@ -74,6 +70,7 @@ namespace oneflow { return Maybe::Ok(); \ }); -OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD, MATH_BINARY_ELEMENTWISE_FUNC_SEQ) +OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_BINARY_ELEMENTWISE_OP_AND_GRAD, + MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS) } // namespace oneflow diff --git a/oneflow/user/ops/math_binary_elementwise_seq.h b/oneflow/user/ops/math_binary_elementwise_seq.h index 37e667f086e..4cdc682d687 100644 --- a/oneflow/user/ops/math_binary_elementwise_seq.h +++ b/oneflow/user/ops/math_binary_elementwise_seq.h @@ -27,6 +27,13 @@ namespace oneflow { OF_PP_MAKE_TUPLE_SEQ("xdivy", Xdivy) \ OF_PP_MAKE_TUPLE_SEQ("xlogy", Xlogy) +#define MATH_BINARY_ELEMENTWISE_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ("pow", Pow) \ + OF_PP_MAKE_TUPLE_SEQ("atan2", Atan2) \ + OF_PP_MAKE_TUPLE_SEQ("floordiv", Floordiv) \ + OF_PP_MAKE_TUPLE_SEQ("xdivy", Xdivy) \ + OF_PP_MAKE_TUPLE_SEQ("xlogy", Xlogy) + } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_BINARY_ELEMENTWISE_SEQ_H_ diff --git a/oneflow/user/ops/math_unary_elementwise_op.cpp b/oneflow/user/ops/math_unary_elementwise_op.cpp index 69cc34fb151..64af43f6316 100644 --- a/oneflow/user/ops/math_unary_elementwise_op.cpp +++ b/oneflow/user/ops/math_unary_elementwise_op.cpp @@ -15,40 +15,45 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/math_unary_elementwise_seq.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD(math_unary_elementwise_type, func_prefix) \ - REGISTER_USER_OP(math_unary_elementwise_type) \ - .Input("x") \ - .Output("y") \ - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType); \ - REGISTER_USER_OP((std::string("") + math_unary_elementwise_type + "_grad")) \ - .Input("x") \ - .Input("dy") \ - .Output("dx") \ - .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) \ - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) \ - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType); \ - REGISTER_USER_OP_GRAD(math_unary_elementwise_type) \ - .SetGenBackwardOpConfFn( \ - [](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { \ - if (op.NeedGenGradTensor4OpInput("x", 0)) { \ - user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); \ - user_op::UserOpConfWrapper unary_grad_op = \ - builder.Op(std::string("") + math_unary_elementwise_type + "_grad") \ - .Input("x", op.input("x", 0)) \ - .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) \ - .Output("dx") \ - .Build(); \ - op.BindGradTensorWithOpInput(unary_grad_op.output("dx", 0), "x", 0); \ - AddOp(unary_grad_op); \ - } \ - return Maybe::Ok(); \ +#define MATH_ELEMENTWISE_DEFAULT_SET_FUNC(op_type) \ + /* static */ Maybe op_type::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::Unchanged(ctx); \ + } \ + /*static*/ Maybe op_type::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /* static */ Maybe op_type::GetSbp(user_op::SbpContext* ctx) { \ + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); \ + } \ + /* static */ Maybe op_type::InferDataType(user_op::InferContext* ctx) { \ + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); \ + } + +#define REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD(math_unary_elementwise_type, func_prefix) \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##Op) \ + MATH_ELEMENTWISE_DEFAULT_SET_FUNC(func_prefix##GradOp) \ + REGISTER_USER_OP_GRAD(math_unary_elementwise_type) \ + .SetGenBackwardOpConfFn( \ + [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { \ + if (op.NeedGenGradTensor4OpInput("x", 0)) { \ + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); \ + user_op::UserOpConfWrapper unary_grad_op = \ + builder.Op(std::string("") + math_unary_elementwise_type + "_grad") \ + .Input("x", op.input("x", 0)) \ + .Input("dy", op.GetGradTensorWithOpOutput("y", 0)) \ + .Output("dx") \ + .Build(); \ + op.BindGradTensorWithOpInput(unary_grad_op.output("dx", 0), "x", 0); \ + AddOp(unary_grad_op); \ + } \ + return Maybe::Ok(); \ }); -OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD, MATH_UNARY_ELEMENTWISE_FUNC_SEQ) +OF_PP_FOR_EACH_TUPLE(REGISTER_MATH_UNARY_ELEMENTWISE_OP_AND_GRAD, + MATH_UNARY_ELEMENTWISE_FUNC_SEQ_ODS) } // namespace oneflow diff --git a/oneflow/user/ops/math_unary_elementwise_seq.h b/oneflow/user/ops/math_unary_elementwise_seq.h index 9106397c0c8..db90cb9a9d3 100644 --- a/oneflow/user/ops/math_unary_elementwise_seq.h +++ b/oneflow/user/ops/math_unary_elementwise_seq.h @@ -56,6 +56,42 @@ namespace oneflow { 
OF_PP_MAKE_TUPLE_SEQ("square", Square) \ OF_PP_MAKE_TUPLE_SEQ("tan", Tan) +#define MATH_UNARY_ELEMENTWISE_FUNC_SEQ_ODS \ + OF_PP_MAKE_TUPLE_SEQ("abs", Abs) \ + OF_PP_MAKE_TUPLE_SEQ("acos", Acos) \ + OF_PP_MAKE_TUPLE_SEQ("acosh", Acosh) \ + OF_PP_MAKE_TUPLE_SEQ("asin", Asin) \ + OF_PP_MAKE_TUPLE_SEQ("asinh", Asinh) \ + OF_PP_MAKE_TUPLE_SEQ("atan", Atan) \ + OF_PP_MAKE_TUPLE_SEQ("atanh", Atanh) \ + OF_PP_MAKE_TUPLE_SEQ("ceil", Ceil) \ + OF_PP_MAKE_TUPLE_SEQ("cos", Cos) \ + OF_PP_MAKE_TUPLE_SEQ("cosh", Cosh) \ + OF_PP_MAKE_TUPLE_SEQ("erf", Erf) \ + OF_PP_MAKE_TUPLE_SEQ("erfc", Erfc) \ + OF_PP_MAKE_TUPLE_SEQ("exp", Exp) \ + OF_PP_MAKE_TUPLE_SEQ("expm1", Expm1) \ + OF_PP_MAKE_TUPLE_SEQ("floor", Floor) \ + OF_PP_MAKE_TUPLE_SEQ("lgamma", Lgamma) \ + OF_PP_MAKE_TUPLE_SEQ("log", Log) \ + OF_PP_MAKE_TUPLE_SEQ("log2", Log2) \ + OF_PP_MAKE_TUPLE_SEQ("log1p", Log1p) \ + OF_PP_MAKE_TUPLE_SEQ("log_sigmoid", LogSigmoid) \ + OF_PP_MAKE_TUPLE_SEQ("negative", Negative) \ + OF_PP_MAKE_TUPLE_SEQ("reciprocal", Reciprocal) \ + OF_PP_MAKE_TUPLE_SEQ("reciprocal_no_nan", ReciprocalNoNan) \ + OF_PP_MAKE_TUPLE_SEQ("rint", Rint) \ + OF_PP_MAKE_TUPLE_SEQ("round", Round) \ + OF_PP_MAKE_TUPLE_SEQ("rsqrt", Rsqrt) \ + OF_PP_MAKE_TUPLE_SEQ("sigmoid_v2", SigmoidV2) \ + OF_PP_MAKE_TUPLE_SEQ("sign", Sign) \ + OF_PP_MAKE_TUPLE_SEQ("sin", Sin) \ + OF_PP_MAKE_TUPLE_SEQ("sinh", Sinh) \ + OF_PP_MAKE_TUPLE_SEQ("softplus", Softplus) \ + OF_PP_MAKE_TUPLE_SEQ("sqrt", Sqrt) \ + OF_PP_MAKE_TUPLE_SEQ("square", Square) \ + OF_PP_MAKE_TUPLE_SEQ("tan", Tan) + } // namespace oneflow #endif // ONEFLOW_USER_OPS_MATH_UNARY_ELEMENTWISE_SEQ_H_ diff --git a/oneflow/user/ops/matmul_op.cpp b/oneflow/user/ops/matmul_op.cpp index b551617e595..4e73d7e3e35 100644 --- a/oneflow/user/ops/matmul_op.cpp +++ b/oneflow/user/ops/matmul_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -141,271 +142,278 @@ void GenBackwardOpConf4Matmul(const std::string& op_type_name, const user_op::Us } // namespace -REGISTER_USER_OP("matmul") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .SetTensorDescInferFn(InferTensorDesc4Matmul) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // (m, k_a) * (k_b, n) where k_a == k_b - int32_t m_axis = -1; - int32_t k_a_axis = -1; - int32_t k_b_axis = -1; - int32_t n_axis = -1; - if (ctx->Attr("transpose_a")) { - m_axis = 1; - k_a_axis = 0; - } else { - m_axis = 0; - k_a_axis = 1; - } - if (ctx->Attr("transpose_b")) { - k_b_axis = 1; - n_axis = 0; - } else { - k_b_axis = 0; - n_axis = 1; - } - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), m_axis) - .Broadcast(user_op::OpArg("b", 0)) - .Split(out_and_add_to_output_args, 0) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .Split(user_op::OpArg("b", 0), n_axis) - .Split(out_and_add_to_output_args, 1) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), k_a_axis) - .Split(user_op::OpArg("b", 0), k_b_axis) - .PartialSum(out_and_add_to_output_args) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("a", 0)) - .Broadcast(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .PartialSum(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType4Matmul); +/* static */ Maybe MatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Matmul(ctx); +} + +/*static*/ Maybe MatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP_GRAD("matmul").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, - user_op::AddOpFn AddOp) -> Maybe { - GenBackwardOpConf4Matmul("matmul", op, AddOp); +/* static */ Maybe MatmulOp::GetSbp(user_op::SbpContext* ctx) { + // (m, k_a) * (k_b, n) where k_a == k_b + int32_t m_axis = -1; + int32_t k_a_axis = -1; + int32_t k_b_axis = -1; + int32_t n_axis = -1; + if (ctx->Attr("transpose_a")) { + m_axis = 1; + k_a_axis = 0; + } else { + m_axis = 0; + k_a_axis = 1; + } + if (ctx->Attr("transpose_b")) { + k_b_axis = 1; + n_axis = 0; + } else { + k_b_axis = 0; + n_axis = 1; + } + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), m_axis) + .Broadcast(user_op::OpArg("b", 0)) + .Split(out_and_add_to_output_args, 0) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .Split(user_op::OpArg("b", 0), n_axis) + .Split(out_and_add_to_output_args, 1) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), k_a_axis) + .Split(user_op::OpArg("b", 0), k_b_axis) + .PartialSum(out_and_add_to_output_args) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("a", 0)) + 
.Broadcast(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .PartialSum(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); return Maybe::Ok(); -}); - -REGISTER_USER_OP("batch_matmul") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .SetTensorDescInferFn(InferTensorDesc4Matmul) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& a_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0); - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - FOR_RANGE(int64_t, i, 0, a_tensor.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(out_and_add_to_output_args, i).Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType4Matmul); +} -REGISTER_USER_OP_GRAD("batch_matmul") - .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, - user_op::AddOpFn AddOp) -> Maybe { - GenBackwardOpConf4Matmul("batch_matmul", op, AddOp); - return Maybe::Ok(); - }); +/* static */ Maybe MatmulOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} -REGISTER_USER_OP("broadcast_matmul") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("transpose_a", false) - .Attr("transpose_b", false) - .Attr("alpha", 1.0) - .SetDataTypeInferFn(InferDataType4Matmul) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - bool transpose_a = ctx->Attr("transpose_a"); - bool transpose_b = ctx->Attr("transpose_b"); - - const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); - const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - - // NOTE: support broadcast b to a for now - // TODO(zwx): support broadcast a to b - CHECK_GT_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); - CHECK_EQ_OR_RETURN(b.shape().NumAxes(), 2); - // NOTE: don't support transpose_a for now - CHECK_OR_RETURN(!transpose_a); +/* static */ Maybe BatchMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDesc4Matmul(ctx); +} - DimVector out_dim_vec(a.shape().NumAxes() - 1); - FOR_RANGE(int64_t, i, 0, out_dim_vec.size()) { out_dim_vec[i] = a.shape().At(i); } - int64_t k = a.shape().At(a.shape().NumAxes() - 1); - int64_t n = -1; - if (!transpose_b) { - CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 2)); - n = b.shape().At(b.shape().NumAxes() - 1); - } else { - CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 1)); - n = b.shape().At(b.shape().NumAxes() - 2); - } - out_dim_vec.emplace_back(n); - *out->mut_shape() = Shape(out_dim_vec); - - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); - } +/*static*/ Maybe BatchMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - // (b, m, k) * (k, n) when transpose_b is false - // (b, m, k) * (n, k) when transpose_b is true - bool transpose_a = ctx->Attr("transpose_a"); - bool transpose_b = 
ctx->Attr("transpose_b"); - CHECK_OR_RETURN(!transpose_a); +/* static */ Maybe BatchMatmulOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& a_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0); + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + FOR_RANGE(int64_t, i, 0, a_tensor.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(out_and_add_to_output_args, i).Build(); + } + return Maybe::Ok(); +} - const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); - int32_t k_a_axis = a_shape.NumAxes() - 1; - int32_t k_b_axis = -1; - int32_t n_axis = -1; - if (transpose_b) { - k_b_axis = 1; - n_axis = 0; - } else { - k_b_axis = 0; - n_axis = 1; - } - - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - - // S(b or m axis) x B -> S(b or m axis) - for (int64_t i = 0; i < a_shape.NumAxes() - 1; ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Broadcast(user_op::OpArg("b", 0)) - .Split(out_and_add_to_output_args, i) - .Build(); - } - // B x S(n_axis) -> S(n_axis) - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .Split(user_op::OpArg("b", 0), n_axis) - .Split(out_and_add_to_output_args, a_shape.NumAxes() - 1) - .Build(); - // S(a_k_axis) x S(b_k_axis) -> P - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), k_a_axis) - .Split(user_op::OpArg("b", 0), k_b_axis) - .PartialSum(out_and_add_to_output_args) - .Build(); - // P x B -> P - ctx->NewBuilder() - .PartialSum(user_op::OpArg("a", 0)) - .Broadcast(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - // B x P -> P - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .PartialSum(user_op::OpArg("b", 0)) - .PartialSum(out_and_add_to_output_args) - .Build(); - return Maybe::Ok(); - }); +/* static */ Maybe BatchMatmulOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} + +/* static */ Maybe BroadcastMatmulOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + bool transpose_a = ctx->Attr("transpose_a"); + bool transpose_b = ctx->Attr("transpose_b"); -REGISTER_USER_OP("broadcast_matmul_grad_b") - .Input("a") - .Input("b") - .OptionalInput("_add_to_output") - .Output("out") - .Attr("alpha", 1.0) - .SetDataTypeInferFn(InferDataType4Matmul) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); - const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - - CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); - for (int i = 0; i < a.shape().NumAxes() - 1; ++i) { - CHECK_EQ_OR_RETURN(a.shape().At(i), b.shape().At(i)); - } - - *out->mut_shape() = - Shape({a.shape().At(a.shape().NumAxes() - 1), b.shape().At(b.shape().NumAxes() - 1)}); - - if (ctx->has_input("_add_to_output", 0)) { - const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); - CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); - } + const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); + const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 
0); + // NOTE: support broadcast b to a for now + // TODO(zwx): support broadcast a to b + CHECK_GT_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); + CHECK_EQ_OR_RETURN(b.shape().NumAxes(), 2); + // NOTE: don't support transpose_a for now + CHECK_OR_RETURN(!transpose_a); + + DimVector out_dim_vec(a.shape().NumAxes() - 1); + FOR_RANGE(int64_t, i, 0, out_dim_vec.size()) { out_dim_vec[i] = a.shape().At(i); } + int64_t k = a.shape().At(a.shape().NumAxes() - 1); + int64_t n = -1; + if (!transpose_b) { + CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 2)); + n = b.shape().At(b.shape().NumAxes() - 1); + } else { + CHECK_EQ_OR_RETURN(k, b.shape().At(b.shape().NumAxes() - 1)); + n = b.shape().At(b.shape().NumAxes() - 2); + } + out_dim_vec.emplace_back(n); + *out->mut_shape() = Shape(out_dim_vec); + + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); + } + + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastMatmulOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastMatmulOp::GetSbp(user_op::SbpContext* ctx) { + // (b, m, k) * (k, n) when transpose_b is false + // (b, m, k) * (n, k) when transpose_b is true + bool transpose_a = ctx->Attr("transpose_a"); + bool transpose_b = ctx->Attr("transpose_b"); + CHECK_OR_RETURN(!transpose_a); + + const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); + int32_t k_a_axis = a_shape.NumAxes() - 1; + int32_t k_b_axis = -1; + int32_t n_axis = -1; + if (transpose_b) { + k_b_axis = 1; + n_axis = 0; + } else { + k_b_axis = 0; + n_axis = 1; + } + + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + + // S(b or m axis) x B -> S(b or m axis) + for (int64_t i = 0; i < a_shape.NumAxes() - 1; ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Broadcast(user_op::OpArg("b", 0)) + .Split(out_and_add_to_output_args, i) + .Build(); + } + // B x S(n_axis) -> S(n_axis) + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .Split(user_op::OpArg("b", 0), n_axis) + .Split(out_and_add_to_output_args, a_shape.NumAxes() - 1) + .Build(); + // S(a_k_axis) x S(b_k_axis) -> P + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), k_a_axis) + .Split(user_op::OpArg("b", 0), k_b_axis) + .PartialSum(out_and_add_to_output_args) + .Build(); + // P x B -> P + ctx->NewBuilder() + .PartialSum(user_op::OpArg("a", 0)) + .Broadcast(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); + // B x P -> P + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .PartialSum(user_op::OpArg("b", 0)) + .PartialSum(out_and_add_to_output_args) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastMatmulOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} + +/* static */ Maybe BroadcastMatmulGradBOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& a = ctx->InputTensorDesc("a", 0); + const user_op::TensorDesc& b = ctx->InputTensorDesc("b", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + + CHECK_EQ_OR_RETURN(a.shape().NumAxes(), b.shape().NumAxes()); + for (int i = 0; i < a.shape().NumAxes() - 1; ++i) { + CHECK_EQ_OR_RETURN(a.shape().At(i), 
b.shape().At(i)); + } + + *out->mut_shape() = + Shape({a.shape().At(a.shape().NumAxes() - 1), b.shape().At(b.shape().NumAxes() - 1)}); + + if (ctx->has_input("_add_to_output", 0)) { + const user_op::TensorDesc& add_to_output = ctx->InputTensorDesc("_add_to_output", 0); + CHECK_EQ_OR_RETURN(add_to_output.shape(), out->shape()); + } + + return Maybe::Ok(); +} + +/*static*/ Maybe BroadcastMatmulGradBOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe BroadcastMatmulGradBOp::GetSbp(user_op::SbpContext* ctx) { + const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); + int64_t last_axis = a_shape.NumAxes() - 1; + + std::vector out_and_add_to_output_args; + out_and_add_to_output_args.emplace_back("out", 0); + if (ctx->user_op_conf().has_input("_add_to_output", 0)) { + out_and_add_to_output_args.emplace_back("_add_to_output", 0); + } + + // S(b or m axis) x S(b or m axis) -> P + for (int64_t i = 0; i < last_axis; ++i) { + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), i) + .Split(user_op::OpArg("b", 0), i) + .PartialSum(out_and_add_to_output_args) + .Build(); + } + + // (b, m, k) * (b, m, n) -> (k, n) [transpose a] + // S(k) x B -> S(0) or B x S(n) -> S(1) + // (b, m, n) * (b, m, k) -> (n, k) [transpose a] + // S(n) x B -> S(0) or B x S(k) -> S(1) + ctx->NewBuilder() + .Split(user_op::OpArg("a", 0), last_axis) + .Broadcast(user_op::OpArg("b", 0)) + .Split(out_and_add_to_output_args, 0) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("a", 0)) + .Split(user_op::OpArg("b", 0), last_axis) + .Split(out_and_add_to_output_args, 1) + .Build(); + + return Maybe::Ok(); +} + +/* static */ Maybe BroadcastMatmulGradBOp::InferDataType(user_op::InferContext* ctx) { + return InferDataType4Matmul(ctx); +} + +REGISTER_USER_OP_GRAD("matmul").SetGenBackwardOpConfFn( + [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe { + GenBackwardOpConf4Matmul("matmul", op, AddOp); return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& a_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("a", 0).shape(); - int64_t last_axis = a_shape.NumAxes() - 1; - - std::vector out_and_add_to_output_args; - out_and_add_to_output_args.emplace_back("out", 0); - if (ctx->user_op_conf().has_input("_add_to_output", 0)) { - out_and_add_to_output_args.emplace_back("_add_to_output", 0); - } - - // S(b or m axis) x S(b or m axis) -> P - for (int64_t i = 0; i < last_axis; ++i) { - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), i) - .Split(user_op::OpArg("b", 0), i) - .PartialSum(out_and_add_to_output_args) - .Build(); - } - - // (b, m, k) * (b, m, n) -> (k, n) [transpose a] - // S(k) x B -> S(0) or B x S(n) -> S(1) - // (b, m, n) * (b, m, k) -> (n, k) [transpose a] - // S(n) x B -> S(0) or B x S(k) -> S(1) - ctx->NewBuilder() - .Split(user_op::OpArg("a", 0), last_axis) - .Broadcast(user_op::OpArg("b", 0)) - .Split(out_and_add_to_output_args, 0) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("a", 0)) - .Split(user_op::OpArg("b", 0), last_axis) - .Split(out_and_add_to_output_args, 1) - .Build(); + }); +REGISTER_USER_OP_GRAD("batch_matmul") + .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, + const user_op::AddOpFn& AddOp) -> Maybe { + GenBackwardOpConf4Matmul("batch_matmul", op, AddOp); return Maybe::Ok(); }); diff --git a/oneflow/user/ops/min_max_observer_op.cpp b/oneflow/user/ops/min_max_observer_op.cpp index d1003ba287f..3d7f186c378 100644 --- 
a/oneflow/user/ops/min_max_observer_op.cpp
+++ b/oneflow/user/ops/min_max_observer_op.cpp
@@ -14,73 +14,65 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"

 namespace oneflow {

-namespace {
+/* static */ Maybe<void> MinMaxObserverOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& in_shape = ctx->InputShape("in", 0);

-REGISTER_NO_GRAD_USER_OP("min_max_observer")
-    .Input("in")
-    .Output("scale")
-    .Output("zero_point")
-    // NOTE(Liang Depeng): "google" or "cambricon"
-    .Attr<std::string>("quantization_formula", "google")
-    // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned
-    // integer
-    .Attr<int32_t>("quantization_bit", 8)
-    // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer
-    .Attr<std::string>("quantization_scheme", "symmetric")
-    // NOTE(Liang Depeng): "true" or "false": per-layer or per-channel quantization.
-    .Attr<bool>("per_layer_quantization", true)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& in_shape = ctx->InputShape("in", 0);
+  if (ctx->Attr<std::string>("quantization_formula") == "google") {
+    if (ctx->Attr<bool>("per_layer_quantization") == true) {
+      *ctx->OutputShape("scale", 0) = Shape({1});
+      *ctx->OutputShape("zero_point", 0) = Shape({1});
+    } else {
+      // NOTE(Liang Depeng): For now per-channel quantization only support axis 0
+      *ctx->OutputShape("scale", 0) = Shape({in_shape.At(0)});
+      *ctx->OutputShape("zero_point", 0) = Shape({in_shape.At(0)});
+    }
+  } else {  // quantization_formula == "cambricon"
+    *ctx->OutputShape("scale", 0) = Shape({1});
+    *ctx->OutputShape("zero_point", 0) = Shape({1});
+  }
+  return Maybe<void>::Ok();
+}

-      if (ctx->Attr<std::string>("quantization_formula") == "google") {
-        if (ctx->Attr<bool>("per_layer_quantization") == true) {
-          *ctx->OutputShape("scale", 0) = Shape({1});
-          *ctx->OutputShape("zero_point", 0) = Shape({1});
-        } else {
-          // NOTE(Liang Depeng): For now per-channel quantization only support axis 0
-          *ctx->OutputShape("scale", 0) = Shape({in_shape.At(0)});
-          *ctx->OutputShape("zero_point", 0) = Shape({in_shape.At(0)});
-        }
-      } else {  // quantization_formula == "cambricon"
-        *ctx->OutputShape("scale", 0) = Shape({1});
-        *ctx->OutputShape("zero_point", 0) = Shape({1});
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0);
-      CHECK_OR_RETURN(in != nullptr);
-      in->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the
-      // global scale and zero_point
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      int32_t quantization_bit = op_conf.attr<int32_t>("quantization_bit");
-      CHECK_GT_OR_RETURN(quantization_bit, 1);
-      CHECK_LE_OR_RETURN(quantization_bit, 8);
+/*static*/ Maybe<void> MinMaxObserverOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}

-      std::string quantization_scheme = op_conf.attr<std::string>("quantization_scheme");
-      CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
+/* static */ Maybe<void> MinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) {
+  // NOTE(Liang Depeng): input needs to be broadcast in order to accurately calculate the
+  // global scale and zero_point
+  return Maybe<void>::Ok();
+}

-      std::string quantization_formula = op_conf.attr<std::string>("quantization_formula");
-      CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0);
-      *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> MinMaxObserverOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0);
+  CHECK_OR_RETURN(in != nullptr);
+  in->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}

-}  // namespace
+/* static */ Maybe<void> MinMaxObserverOp::CheckAttr(const user_op::UserOpDefWrapper& def,
+                                                     const user_op::UserOpConfWrapper& op_conf) {
+  int32_t quantization_bit = op_conf.attr<int32_t>("quantization_bit");
+  CHECK_GT_OR_RETURN(quantization_bit, 1);
+  CHECK_LE_OR_RETURN(quantization_bit, 8);
+
+  std::string quantization_scheme = op_conf.attr<std::string>("quantization_scheme");
+  CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
+
+  std::string quantization_formula = op_conf.attr<std::string>("quantization_formula");
+  CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MinMaxObserverOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0);
+  *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}

 }  // namespace oneflow
diff --git a/oneflow/user/ops/mish_op.cpp b/oneflow/user/ops/mish_op.cpp
index 4f51ca76034..9b3c04bf17d 100644
--- a/oneflow/user/ops/mish_op.cpp
+++ b/oneflow/user/ops/mish_op.cpp
@@ -14,61 +14,62 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"

 namespace oneflow {

-namespace {
+/* static */ Maybe<void> MishOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+  return Maybe<void>::Ok();
+}

-REGISTER_USER_OP("mish")
-    .Input("in")
-    .Output("out")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> MishOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}

-REGISTER_USER_OP("mish_grad")
-    .Input("x")
-    .Input("dy")
-    .Output("dx")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      Shape* dx_shape = ctx->OutputShape("dx", 0);
-      CHECK(dy_shape == x_shape);
-      *dx_shape = dy_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("x", 0), i)
-            .Split(user_op::OpArg("dy", 0), i)
-            .Split(user_op::OpArg("dx", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> MishOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MishOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MishGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  Shape* dx_shape = ctx->OutputShape("dx", 0);
+  CHECK(dy_shape == x_shape);
+  *dx_shape = dy_shape;
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> MishGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> MishGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("x", 0), i)
+        .Split(user_op::OpArg("dy", 0), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MishGradOp::InferDataType(user_op::InferContext* ctx) {
+  CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0),
ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("mish").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -87,6 +88,4 @@ REGISTER_USER_OP_GRAD("mish").SetBackwardOpConfGenFn( return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/model_update_ops.cpp b/oneflow/user/ops/model_update_ops.cpp index 5c842cea319..d0da22056f9 100644 --- a/oneflow/user/ops/model_update_ops.cpp +++ b/oneflow/user/ops/model_update_ops.cpp @@ -17,6 +17,7 @@ limitations under the License. #include "oneflow/core/framework/infer_util.h" #include "oneflow/core/framework/user_op_conf.h" #include "oneflow/core/framework/user_op_registry.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -427,361 +428,369 @@ Maybe InferLarsUpdateDataType(user_op::InferContext* ctx) { } return Maybe::Ok(); } -REGISTER_NO_GRAD_USER_OP("sgd_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferSGDUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(SgdInputArgModifyFn) - .SetDataTypeInferFn(InferSGDUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("indexed_slices_sgd_update") - .Input("model") - .Input("model_diff_indices") - .Input("model_diff_values") - .Input("learning_rate") - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferIndexedSlicesSGDUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - const user_op::TensorDesc& model_diff_indices = - ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Broadcast(user_op::OpArg("model_diff_values", 0)) - .Split(user_op::OpArg("model", 0), 0) - .Build(); - FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Split(user_op::OpArg("model_diff_values", 0), - model_diff_indices.shape().NumAxes() + i - 1) - .Split(user_op::OpArg("model", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(IndexedSlicesSgdInputArgModifyFn) - .SetDataTypeInferFn(InferIndexedSlicesSGDUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("momentum_update") - .Input("model") - .Input("model_diff") - .Input("momentum") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("beta", 0.9) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferMomentumUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); 
- FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("momentum", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(MomentumInputArgModifyFn) - .SetDataTypeInferFn(InferMomentumUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("indexed_slices_momentum_update") - .Input("model") - .Input("model_diff_indices") - .Input("model_diff_values") - .Input("learning_rate") - .Input("momentum") - .Attr("beta", 0.9) - .Attr("weight_decay", 0.0) - .SetTensorDescInferFn(InferIndexedSlicesMomentumUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - const user_op::TensorDesc& model_diff_indices = - ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + +} // namespace + +/* static */ Maybe SgdUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferSGDUpdateTensorDesc(ctx); +} + +/*static*/ Maybe SgdUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe SgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe SgdUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return SgdInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe SgdUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferSGDUpdateDataType(ctx); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferIndexedSlicesSGDUpdateTensorDesc(ctx); +} + +/*static*/ Maybe IndexedSlicesSgdUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + const user_op::TensorDesc& model_diff_indices = + ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Broadcast(user_op::OpArg("model_diff_values", 0)) + .Split(user_op::OpArg("model", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) + .Split(user_op::OpArg("model", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return IndexedSlicesSgdInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe IndexedSlicesSgdUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferIndexedSlicesSGDUpdateDataType(ctx); +} + +/* static */ Maybe 
MomentumUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferMomentumUpdateTensorDesc(ctx); +} + +/*static*/ Maybe MomentumUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe MomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("momentum", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe MomentumUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return MomentumInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe MomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferMomentumUpdateDataType(ctx); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferIndexedSlicesMomentumUpdateTensorDesc(ctx); +} + +/*static*/ Maybe IndexedSlicesMomentumUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + const user_op::TensorDesc& model_diff_indices = + ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Broadcast(user_op::OpArg("model_diff_values", 0)) + .Split(user_op::OpArg("model", 0), 0) + .Split(user_op::OpArg("momentum", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("learning_rate", 0)) + .Broadcast(user_op::OpArg("model_diff_indices", 0)) + .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) + .Split(user_op::OpArg("model", 0), i) + .Split(user_op::OpArg("momentum", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return IndexedSlicesMomentumInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe IndexedSlicesMomentumUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferIndexedSlicesMomentumUpdateDataType(ctx); +} + +/* static */ Maybe AdamUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferAdamUpdateTensorDesc(ctx); +} + +/*static*/ Maybe AdamUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("m", 0), axis) + .Split(user_op::OpArg("v", 0), axis) + .Split(user_op::OpArg("max_v", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe 
AdamUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return AdamInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe AdamUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferAdamUpdateDataType(ctx); +} + +/* static */ Maybe AdagradUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferAdagradUpdateTensorDesc(ctx); +} + +/*static*/ Maybe AdagradUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdagradUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("sum", 0), axis) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe AdagradUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return AdagradInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe AdagradUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferAdagradUpdateDataType(ctx); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferIndexedSlicesAdamUpdateTensorDesc(ctx); +} + +/*static*/ Maybe IndexedSlicesAdamUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + const user_op::TensorDesc& model_diff_indices = + ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); + std::vector broadcast_args; + broadcast_args.emplace_back("learning_rate", 0); + broadcast_args.emplace_back("model_diff_indices", 0); + ctx->NewBuilder() + .Broadcast(broadcast_args) + .Broadcast(user_op::OpArg("model_diff_values", 0)) + .Split(user_op::OpArg("model", 0), 0) + .Split(user_op::OpArg("m", 0), 0) + .Split(user_op::OpArg("v", 0), 0) + .Split(user_op::OpArg("max_v", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { + ctx->NewBuilder() + .Broadcast(broadcast_args) + .Split(user_op::OpArg("model_diff_values", 0), model_diff_indices.shape().NumAxes() + i - 1) + .Split(user_op::OpArg("model", 0), i) + .Split(user_op::OpArg("m", 0), i) + .Split(user_op::OpArg("v", 0), i) + .Split(user_op::OpArg("max_v", 0), i) + .Build(); + } + return Maybe::Ok(); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return AdamInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe IndexedSlicesAdamUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferIndexedSlicesAdamUpdateDataType(ctx); +} + +/* static */ Maybe LambUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferLambUpdateTensorDesc(ctx); +} + +/*static*/ Maybe LambUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe LambUpdateOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static 
*/ Maybe LambUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return LambInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe LambUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferLambUpdateDataType(ctx); +} + +/* static */ Maybe AdamBiasCorrectionFactorOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("train_step", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe AdamBiasCorrectionFactorOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe AdamBiasCorrectionFactorOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe AdamBiasCorrectionFactorOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} + +/* static */ Maybe RmspropUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferRmsPropUpdateTensorDesc(ctx); +} + +/*static*/ Maybe RmspropUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe RmspropUpdateOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); + bool centered = ctx->Attr("centered"); + FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { + if (centered) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Broadcast(user_op::OpArg("model_diff_values", 0)) - .Split(user_op::OpArg("model", 0), 0) - .Split(user_op::OpArg("momentum", 0), 0) + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("mean_square", 0), axis) + .Split(user_op::OpArg("mean_gradient", 0), axis) .Build(); - FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("learning_rate", 0)) - .Broadcast(user_op::OpArg("model_diff_indices", 0)) - .Split(user_op::OpArg("model_diff_values", 0), - model_diff_indices.shape().NumAxes() + i - 1) - .Split(user_op::OpArg("model", 0), i) - .Split(user_op::OpArg("momentum", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(IndexedSlicesMomentumInputArgModifyFn) - .SetDataTypeInferFn(InferIndexedSlicesMomentumUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("adam_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .OptionalInput("bias_correction1") - .OptionalInput("bias_correction2") - .Input("m") - .Input("v") - .Input("max_v") - .Attr("learning_rate_val", 0.0) - .Attr("bias_correction1_val", 1.0) - .Attr("bias_correction2_val", 1.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("beta1", 0.9) - .Attr("beta2", 0.999) - .Attr("epsilon", 1e-8) - .Attr("weight_decay", 0.0) - .Attr("amsgrad", false) - .Attr("do_bias_correction", true) - .SetTensorDescInferFn(InferAdamUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - 
.Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("m", 0), axis) - .Split(user_op::OpArg("v", 0), axis) - .Split(user_op::OpArg("max_v", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(AdamInputArgModifyFn) - .SetDataTypeInferFn(InferAdamUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("adagrad_update") - .Input("model") - .Input("model_diff") - .OptionalInput("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .OptionalInput("train_step") - .Input("sum") - .Attr("train_step_val", 0) - .Attr("learning_rate_val", 0.0) - .Attr("scale", 1.0) - .Attr("l1", 0.0) - .Attr("l2", 0.0) - .Attr("lr_decay", 0.0) - .Attr("weight_decay", 0.0) - .Attr("epsilon", 1e-10) - .SetTensorDescInferFn(InferAdagradUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(ctx->inputs()) - .Split(user_op::OpArg("model", 0), axis) - .Split(user_op::OpArg("model_diff", 0), axis) - .Split(user_op::OpArg("sum", 0), axis) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(AdagradInputArgModifyFn) - .SetDataTypeInferFn(InferAdagradUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("indexed_slices_adam_update") - .Input("model") - .Input("model_diff_indices") - .Input("model_diff_values") - .Input("learning_rate") - .OptionalInput("bias_correction1") - .OptionalInput("bias_correction2") - .Input("m") - .Input("v") - .Input("max_v") - .Attr("learning_rate_val", 0.0) - .Attr("beta1", 0.9) - .Attr("beta2", 0.999) - .Attr("epsilon", 1e-8) - .Attr("weight_decay", 0.0) - .Attr("amsgrad", false) - .Attr("do_bias_correction", true) - .SetTensorDescInferFn(InferIndexedSlicesAdamUpdateTensorDesc) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0); - const user_op::TensorDesc& model_diff_indices = - ctx->LogicalTensorDesc4InputArgNameAndIndex("model_diff_indices", 0); - std::vector broadcast_args; - broadcast_args.emplace_back("learning_rate", 0); - broadcast_args.emplace_back("model_diff_indices", 0); + } else { ctx->NewBuilder() - .Broadcast(broadcast_args) - .Broadcast(user_op::OpArg("model_diff_values", 0)) - .Split(user_op::OpArg("model", 0), 0) - .Split(user_op::OpArg("m", 0), 0) - .Split(user_op::OpArg("v", 0), 0) - .Split(user_op::OpArg("max_v", 0), 0) + .Broadcast(ctx->inputs()) + .Split(user_op::OpArg("model", 0), axis) + .Split(user_op::OpArg("model_diff", 0), axis) + .Split(user_op::OpArg("mean_square", 0), axis) .Build(); - FOR_RANGE(int64_t, i, 1, model.shape().NumAxes()) { - ctx->NewBuilder() - .Broadcast(broadcast_args) - .Split(user_op::OpArg("model_diff_values", 0), - model_diff_indices.shape().NumAxes() + i - 1) - .Split(user_op::OpArg("model", 0), i) - .Split(user_op::OpArg("m", 0), i) - .Split(user_op::OpArg("v", 0), i) - .Split(user_op::OpArg("max_v", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetInputArgModifyFn(AdamInputArgModifyFn) - .SetDataTypeInferFn(InferIndexedSlicesAdamUpdateDataType); - -REGISTER_NO_GRAD_USER_OP("lamb_update") - .Input("m") - .Input("v") - .Input("beta1_t") - .Input("beta2_t") - .Input("model") - .Input("model_diff") - .Input("learning_rate") - .OptionalInput("scale_by_tensor") - .OptionalInput("skip_if") - .Attr("beta1") - .Attr("beta2") - .Attr("epsilon") - .Attr("scale", 1.0) - .Attr("l1", 0.0) - 
-REGISTER_NO_GRAD_USER_OP("lamb_update")
-    .Input("m")
-    .Input("v")
-    .Input("beta1_t")
-    .Input("beta2_t")
-    .Input("model")
-    .Input("model_diff")
-    .Input("learning_rate")
-    .OptionalInput("scale_by_tensor")
-    .OptionalInput("skip_if")
-    .Attr<float>("beta1")
-    .Attr<float>("beta2")
-    .Attr<float>("epsilon")
-    .Attr<double>("scale", 1.0)
-    .Attr<float>("l1", 0.0)
-    .Attr<float>("l2", 0.0)
-    .Attr<float>("weight_decay", 0.0)
-    .SetTensorDescInferFn(InferLambUpdateTensorDesc)
-    // every bn has sbp broadcast signature
-    .SetInputArgModifyFn(LambInputArgModifyFn)
-    .SetDataTypeInferFn(InferLambUpdateDataType)
-    .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast);
-
-REGISTER_NO_GRAD_USER_OP("adam_bias_correction_factor")
-    .Input("train_step")
-    .Output("out")
-    .Attr<float>("beta", 0.9)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("train_step", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = DataType::kFloat;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast);
-
-// every bn has sbp broadcast signature
-
-REGISTER_NO_GRAD_USER_OP("rmsprop_update")
-    .Input("model")
-    .Input("model_diff")
-    .OptionalInput("learning_rate")
-    .OptionalInput("scale_by_tensor")
-    .OptionalInput("skip_if")
-    .Input("mean_square")
-    .OptionalInput("mean_gradient")
-    .Attr<float>("learning_rate_val", 0.0)
-    .Attr<double>("scale", 1.0)
-    .Attr<float>("l1", 0.0)
-    .Attr<float>("l2", 0.0)
-    .Attr<bool>("centered", false)
-    .Attr<float>("epsilon", 1e-8)
-    .Attr<float>("decay_rate", 0.99)
-    .Attr<float>("weight_decay", 0.0)
-    .SetTensorDescInferFn(InferRmsPropUpdateTensorDesc)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
-      bool centered = ctx->Attr<bool>("centered");
-      FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
-        if (centered) {
-          ctx->NewBuilder()
-              .Broadcast(ctx->inputs())
-              .Split(user_op::OpArg("model", 0), axis)
-              .Split(user_op::OpArg("model_diff", 0), axis)
-              .Split(user_op::OpArg("mean_square", 0), axis)
-              .Split(user_op::OpArg("mean_gradient", 0), axis)
-              .Build();
-        } else {
-          ctx->NewBuilder()
-              .Broadcast(ctx->inputs())
-              .Split(user_op::OpArg("model", 0), axis)
-              .Split(user_op::OpArg("model_diff", 0), axis)
-              .Split(user_op::OpArg("mean_square", 0), axis)
-              .Build();
-        }
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn(RmsPropUpdateInputArgModifyFn)
-    .SetDataTypeInferFn(InferRmsPropUpdateDataType);
-
-REGISTER_NO_GRAD_USER_OP("lars_update")
-    .Input("model")
-    .Input("model_diff")
-    .Input("learning_rate")
-    .Input("momentum")
-    .OptionalInput("scale_by_tensor")
-    .OptionalInput("skip_if")
-    .Attr<double>("scale", 1.0)
-    .Attr<float>("l1", 0.0)
-    .Attr<float>("l2", 0.0)
-    .Attr<float>("momentum_beta", 0.9)
-    .Attr<float>("epsilon", 1e-9)
-    .Attr<float>("lars_coefficient", 1e-4)
-    .Attr<float>("weight_decay", 0.0)
-    .SetTensorDescInferFn(InferLarsUpdateTensorDesc)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
-      FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Broadcast(ctx->inputs())
-            .Split(user_op::OpArg("model", 0), axis)
-            .Split(user_op::OpArg("model_diff", 0), axis)
-            .Split(user_op::OpArg("momentum", 0), axis)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn(LarsUpdateInputArgModifyFn)
-    .SetDataTypeInferFn(InferLarsUpdateDataType);
+    }
+  }
+  return Maybe<void>::Ok();
+}
 
-}  // namespace
+/* static */ Maybe<void> RmspropUpdateOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  return RmsPropUpdateInputArgModifyFn(GetInputArgModifierFn, conf);
+}
+
+/* static */ Maybe<void> RmspropUpdateOp::InferDataType(user_op::InferContext* ctx) {
+  return InferRmsPropUpdateDataType(ctx);
+}
+
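Aside: the recurring shape of this migration is that each builder-chain registration becomes a set of static methods on a class declared in the generated header `op_generated.h`. A rough sketch of what such a declaration is assumed to look like; `FooOp` and the stub types below are hypothetical stand-ins, since the real declarations are emitted by the op-schema generator:

    // Hypothetical sketch only; not OneFlow source.
    template<typename T> struct Maybe { static Maybe Ok() { return {}; } };
    namespace user_op {
    struct InferContext;
    struct SbpContext;
    }  // namespace user_op

    class FooOp {
     public:
      static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
      static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);
      static Maybe<void> GetSbp(user_op::SbpContext* ctx);
      static Maybe<void> InferDataType(user_op::InferContext* ctx);
      // Ops that tweak their inputs (requires_grad, mutability) also get a
      // ModifyInputArg hook, and ops with attribute constraints get CheckAttr.
    };
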
+/* static */ Maybe<void> LarsUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLarsUpdateTensorDesc(ctx);
+}
+
+/*static*/ Maybe<void> LarsUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> LarsUpdateOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& model = ctx->LogicalTensorDesc4InputArgNameAndIndex("model", 0);
+  FOR_RANGE(int64_t, axis, 0, model.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Broadcast(ctx->inputs())
+        .Split(user_op::OpArg("model", 0), axis)
+        .Split(user_op::OpArg("model_diff", 0), axis)
+        .Split(user_op::OpArg("momentum", 0), axis)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> LarsUpdateOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  return LarsUpdateInputArgModifyFn(GetInputArgModifierFn, conf);
+}
+
+/* static */ Maybe<void> LarsUpdateOp::InferDataType(user_op::InferContext* ctx) {
+  return InferLarsUpdateDataType(ctx);
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/moving_average_min_max_observer_op.cpp b/oneflow/user/ops/moving_average_min_max_observer_op.cpp
index 8c4c59dc8e1..434865f2d59 100644
--- a/oneflow/user/ops/moving_average_min_max_observer_op.cpp
+++ b/oneflow/user/ops/moving_average_min_max_observer_op.cpp
@@ -14,94 +14,82 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
-
-REGISTER_NO_GRAD_USER_OP("moving_average_min_max_observer")
-    .Input("in")
-    .Input("current_train_step")
-    .Input("moving_max")  // NOTE(Liang Depeng): needs to be initialized as 0
-    .Input("moving_min")  // NOTE(Liang Depeng): needs to be initialized as 0
-    .Output("scale")
-    .Output("zero_point")
-    .Attr<bool>("training")
-    // NOTE(Liang Depeng): "google" or "cambricon"
-    .Attr<std::string>("quantization_formula", "google")
-    .Attr<int64_t>("stop_update_after_iters")
-    // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned
-    // integer
-    .Attr<int32_t>("quantization_bit", 8)
-    // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer
-    .Attr<std::string>("quantization_scheme", "symmetric")
-    // NOTE(Liang Depeng): smoothing parameter for exponential moving average operation
-    .Attr<float>("momentum", 0.95)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& moving_max_shape = ctx->InputShape("moving_max", 0);
-      const Shape& moving_min_shape = ctx->InputShape("moving_min", 0);
-      const Shape& current_train_step = ctx->InputShape("current_train_step", 0);
-
-      // NOTE(Liang Depeng): for now only support per-layer quantization
-      // TODO(Liang Depeng): depthwise convolution support per-channel quantization
-      CHECK_OR_RETURN(moving_max_shape.NumAxes() == 1 && moving_max_shape.At(0) == 1);
-      CHECK_OR_RETURN(moving_min_shape.NumAxes() == 1 && moving_min_shape.At(0) == 1);
-
-      CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1);
-
-      *ctx->OutputShape("scale", 0) = Shape({1});
-      *ctx->OutputShape("zero_point", 0) = Shape({1});
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0);
-      *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0);
-      CHECK_OR_RETURN(in != nullptr);
-      in->set_requires_grad(false);
-
-      user_op::InputArgModifier* current_train_step =
-          GetInputArgModifierFn("current_train_step", 0);
-      CHECK_OR_RETURN(current_train_step != nullptr);
-      current_train_step->set_requires_grad(false);
-
-      user_op::InputArgModifier* moving_max = GetInputArgModifierFn("moving_max", 0);
-      CHECK_OR_RETURN(moving_max != nullptr);
-      moving_max->set_requires_grad(false);
-      moving_max->set_is_mutable(true);
-
-      user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0);
-      CHECK_OR_RETURN(moving_min != nullptr);
-      moving_min->set_requires_grad(false);
-      moving_min->set_is_mutable(true);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      // NOTE(Liang Depeng): all inputs need to be broadcast in order to accurately calculate the
-      // global scale and zero_point
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      int32_t quantization_bit = op_conf.attr<int32_t>("quantization_bit");
-      CHECK_GT_OR_RETURN(quantization_bit, 1);
-      CHECK_LE_OR_RETURN(quantization_bit, 8);
-
-      std::string quantization_scheme = op_conf.attr<std::string>("quantization_scheme");
-      CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
-
-      int64_t stop_update_after_iters = op_conf.attr<int64_t>("stop_update_after_iters");
-      CHECK_GT_OR_RETURN(stop_update_after_iters, 0);
-
-      std::string quantization_formula = op_conf.attr<std::string>("quantization_formula");
-      CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
-      return Maybe<void>::Ok();
-    });
-
-}  // namespace
+/* static */ Maybe<void> MovingAverageMinMaxObserverOp::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  const Shape& moving_max_shape = ctx->InputShape("moving_max", 0);
+  const Shape& moving_min_shape = ctx->InputShape("moving_min", 0);
+  const Shape& current_train_step = ctx->InputShape("current_train_step", 0);
+
+  // NOTE(Liang Depeng): for now only support per-layer quantization
+  // TODO(Liang Depeng): depthwise convolution support per-channel quantization
+  CHECK_OR_RETURN(moving_max_shape.NumAxes() == 1 && moving_max_shape.At(0) == 1);
+  CHECK_OR_RETURN(moving_min_shape.NumAxes() == 1 && moving_min_shape.At(0) == 1);
+
+  CHECK_OR_RETURN(current_train_step.NumAxes() == 1 && current_train_step.At(0) == 1);
+
+  *ctx->OutputShape("scale", 0) = Shape({1});
+  *ctx->OutputShape("zero_point", 0) = Shape({1});
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> MovingAverageMinMaxObserverOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> MovingAverageMinMaxObserverOp::GetSbp(user_op::SbpContext* ctx) {
+  // NOTE(Liang Depeng): all inputs need to be broadcast in order to accurately calculate the
+  // global scale and zero_point
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MovingAverageMinMaxObserverOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* in = GetInputArgModifierFn("in", 0);
+  CHECK_OR_RETURN(in != nullptr);
+  in->set_requires_grad(false);
+
+  user_op::InputArgModifier* current_train_step = GetInputArgModifierFn("current_train_step", 0);
+  CHECK_OR_RETURN(current_train_step != nullptr);
+  current_train_step->set_requires_grad(false);
+
+  user_op::InputArgModifier* moving_max = GetInputArgModifierFn("moving_max", 0);
+  CHECK_OR_RETURN(moving_max != nullptr);
+  moving_max->set_requires_grad(false);
+  moving_max->set_is_mutable(true);
+
+  user_op::InputArgModifier* moving_min = GetInputArgModifierFn("moving_min", 0);
+  CHECK_OR_RETURN(moving_min != nullptr);
+  moving_min->set_requires_grad(false);
+  moving_min->set_is_mutable(true);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MovingAverageMinMaxObserverOp::CheckAttr(
+    const user_op::UserOpDefWrapper& def, const user_op::UserOpConfWrapper& op_conf) {
+  int32_t quantization_bit = op_conf.attr<int32_t>("quantization_bit");
+  CHECK_GT_OR_RETURN(quantization_bit, 1);
+  CHECK_LE_OR_RETURN(quantization_bit, 8);
+
+  std::string quantization_scheme = op_conf.attr<std::string>("quantization_scheme");
+  CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine");
+
+  int64_t stop_update_after_iters = op_conf.attr<int64_t>("stop_update_after_iters");
+  CHECK_GT_OR_RETURN(stop_update_after_iters, 0);
+
+  std::string quantization_formula = op_conf.attr<std::string>("quantization_formula");
+  CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon");
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MovingAverageMinMaxObserverOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("scale", 0) = ctx->InputDType("in", 0);
+  *ctx->OutputDType("zero_point", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
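Aside: the observer's kernel is not part of this diff; purely as an illustration, the `momentum` attribute, described in the removed comments as the smoothing parameter of an exponential moving average, is assumed to enter the running min/max update in the usual EMA form. That in-place update of the statistics is also why `ModifyInputArg` marks `moving_max`/`moving_min` as mutable:

    // Assumed EMA form suggested by the "momentum" attribute (default 0.95);
    // not this diff's kernel code.
    double UpdateMovingStat(double moving, double batch_stat, double momentum = 0.95) {
      return momentum * moving + (1.0 - momentum) * batch_stat;
    }
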
diff --git a/oneflow/user/ops/multiply_op.cpp b/oneflow/user/ops/multiply_op.cpp
index 59e45557025..18d6fa26a44 100644
--- a/oneflow/user/ops/multiply_op.cpp
+++ b/oneflow/user/ops/multiply_op.cpp
@@ -14,48 +14,51 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("multiply")
-    .Input("x")
-    .Input("y")
-    .Output("out")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
-      const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
-      CHECK_OR_RETURN(x.shape() == y.shape());
-      *out->mut_shape() = x.shape();
-      *out->mut_is_dynamic() = x.is_dynamic();
-      if (x.is_dynamic() || y.is_dynamic()) { *out->mut_is_dynamic() = true; }
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-      }
-      ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("x", 0))
-          .Broadcast(user_op::OpArg("y", 0))
-          .PartialSum(user_op::OpArg("out", 0))
-          .Build();
-      ctx->NewBuilder()
-          .Broadcast(user_op::OpArg("x", 0))
-          .PartialSum(user_op::OpArg("y", 0))
-          .PartialSum(user_op::OpArg("out", 0))
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
-      const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
-      CHECK_OR_RETURN(x.data_type() == y.data_type());
-      *out->mut_data_type() = x.data_type();
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> MultiplyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
+  CHECK_OR_RETURN(x.shape() == y.shape());
+  *out->mut_shape() = x.shape();
+  *out->mut_is_dynamic() = x.is_dynamic();
+  if (x.is_dynamic() || y.is_dynamic()) { *out->mut_is_dynamic() = true; }
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> MultiplyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> MultiplyOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  FOR_RANGE(int64_t, i, 0, x.shape().NumAxes()) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("x", 0))
+      .Broadcast(user_op::OpArg("y", 0))
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("x", 0))
+      .PartialSum(user_op::OpArg("y", 0))
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> MultiplyOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& y = ctx->InputTensorDesc("y", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
+  CHECK_OR_RETURN(x.data_type() == y.data_type());
+  *out->mut_data_type() = x.data_type();
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("multiply")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
diff --git a/oneflow/user/ops/narrow_op.cpp b/oneflow/user/ops/narrow_op.cpp
index 0ca17e284ef..aebfd5a9262 100644
--- a/oneflow/user/ops/narrow_op.cpp
+++ b/oneflow/user/ops/narrow_op.cpp
@@ -14,125 +14,125 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("narrow")
-    .Input("in")
-    .Output("out")
-    .Attr<int64_t>("dim")
-    .Attr<int64_t>("start")
-    .Attr<int64_t>("length")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0);
-      const int64_t& dim = ctx->Attr<int64_t>("dim");
-      const int64_t& start = ctx->Attr<int64_t>("start");
-      const int64_t& length = ctx->Attr<int64_t>("length");
-      CHECK_GE_OR_RETURN(dim, 0);
-      CHECK_GE_OR_RETURN(start, 0);
-      CHECK_GE_OR_RETURN(length, 0);
-      CHECK_GE_OR_RETURN(in.shape().At(dim), start + length);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
-
-      DimVector dim_vec;
-      dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(),
-                     in.shape().dim_vec().cbegin() + dim);
-      dim_vec.insert(dim_vec.end(), length);
-      dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + dim + 1,
-                     in.shape().dim_vec().end());
-      *out->mut_shape() = Shape(dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      const int64_t& dim = ctx->Attr<int64_t>("dim");
-      const int64_t& length = ctx->Attr<int64_t>("length");
-      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
-        if (i != dim) {
-          ctx->NewBuilder()
-              .Split(user_op::OpArg("in", 0), i)
-              .Split(user_op::OpArg("out", 0), i)
-              .Build();
-        } else {
-          if (length == in_tensor.shape().At(i)) {
-            ctx->NewBuilder()
-                .Split(user_op::OpArg("in", 0), i)
-                .Split(user_op::OpArg("out", 0), i)
-                .Build();
-          }
-        }
-      }
+/* static */ Maybe<void> NarrowOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  CHECK_GT_OR_RETURN(in.shape().NumAxes(), 0);
+  const int64_t& dim = ctx->Attr<int64_t>("dim");
+  const int64_t& start = ctx->Attr<int64_t>("start");
+  const int64_t& length = ctx->Attr<int64_t>("length");
+  CHECK_GE_OR_RETURN(dim, 0);
+  CHECK_GE_OR_RETURN(start, 0);
+  CHECK_GE_OR_RETURN(length, 0);
+  CHECK_GE_OR_RETURN(in.shape().At(dim), start + length);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
+
+  DimVector dim_vec;
+  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cbegin() + dim);
+  dim_vec.insert(dim_vec.end(), length);
+  dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin() + dim + 1,
+                 in.shape().dim_vec().end());
+  *out->mut_shape() = Shape(dim_vec);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> NarrowOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> NarrowOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  const int64_t& dim = ctx->Attr<int64_t>("dim");
+  const int64_t& length = ctx->Attr<int64_t>("length");
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+    if (i != dim) {
       ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("in", 0))
-          .PartialSum(user_op::OpArg("out", 0))
+          .Split(user_op::OpArg("in", 0), i)
+          .Split(user_op::OpArg("out", 0), i)
           .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
-      *out->mut_data_type() = in.data_type();
-      return Maybe<void>::Ok();
-    });
-
-REGISTER_USER_OP("narrow_grad")
-    .Input("dy")
-    .Input("like")
-    .Output("dx")
-    .Attr<int64_t>("dim")
-    .Attr<int64_t>("start")
-    .Attr<int64_t>("length")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& like_shape = ctx->InputShape("like", 0);
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      const int64_t ndim = dy_shape.NumAxes();
-      CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim);
-
-      *ctx->OutputShape("dx", 0) = like_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape();
-      const int64_t ndim = like_shape.NumAxes();
-      const int64_t& dim = ctx->Attr<int64_t>("dim");
-      const int64_t& length = ctx->Attr<int64_t>("length");
-      FOR_RANGE(int64_t, i, 0, ndim) {
-        if (i != dim) {
-          ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-        } else {
-          if (length == like_shape.At(i)) {
-            ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-          }
-        }
+    } else {
+      if (length == in_tensor.shape().At(i)) {
+        ctx->NewBuilder()
+            .Split(user_op::OpArg("in", 0), i)
+            .Split(user_op::OpArg("out", 0), i)
+            .Build();
       }
-      ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-      ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("dy", 0))
-          .Broadcast(user_op::OpArg("like", 0))
-          .PartialSum(user_op::OpArg("dx", 0))
-          .Build();
-      ctx->NewBuilder()
-          .Broadcast(user_op::OpArg("dy", 0))
-          .PartialSum(user_op::OpArg("like", 0))
-          .Broadcast(user_op::OpArg("dx", 0))
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper& conf) -> Maybe<void> {
-      user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0);
-      CHECK_NOTNULL_OR_RETURN(dy_modifier);
-      dy_modifier->set_requires_grad(false);
-      user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0);
-      CHECK_NOTNULL_OR_RETURN(like_modifier);
-      like_modifier->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    });
+    }
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("in", 0))
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> NarrowOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
+  *out->mut_data_type() = in.data_type();
+  return Maybe<void>::Ok();
+}
+
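Aside on the shape rule in `NarrowOp::InferLogicalTensorDesc` above: every axis is kept except `dim`, whose extent becomes `length`; `start` only offsets the slice and never changes the shape. A standalone check with assumed example shapes:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // out shape = in shape with dimension `dim` replaced by `length`.
    std::vector<int64_t> NarrowShape(std::vector<int64_t> in, int64_t dim, int64_t length) {
      in[dim] = length;
      return in;
    }

    int main() {
      // narrow(in: (4, 5, 6), dim=1, start=1, length=3) -> (4, 3, 6)
      assert((NarrowShape({4, 5, 6}, 1, 3) == std::vector<int64_t>{4, 3, 6}));
      return 0;
    }
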
+/* static */ Maybe<void> NarrowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& like_shape = ctx->InputShape("like", 0);
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  const int64_t ndim = dy_shape.NumAxes();
+  CHECK_EQ_OR_RETURN(like_shape.NumAxes(), ndim);
+
+  *ctx->OutputShape("dx", 0) = like_shape;
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> NarrowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> NarrowGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape();
+  const int64_t ndim = like_shape.NumAxes();
+  const int64_t& dim = ctx->Attr<int64_t>("dim");
+  const int64_t& length = ctx->Attr<int64_t>("length");
+  FOR_RANGE(int64_t, i, 0, ndim) {
+    if (i != dim) {
+      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+    } else {
+      if (length == like_shape.At(i)) {
+        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+      }
+    }
+  }
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("dy", 0))
+      .Broadcast(user_op::OpArg("like", 0))
+      .PartialSum(user_op::OpArg("dx", 0))
+      .Build();
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("dy", 0))
+      .PartialSum(user_op::OpArg("like", 0))
+      .Broadcast(user_op::OpArg("dx", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> NarrowGradOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0);
+  CHECK_NOTNULL_OR_RETURN(dy_modifier);
+  dy_modifier->set_requires_grad(false);
+  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0);
+  CHECK_NOTNULL_OR_RETURN(like_modifier);
+  like_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> NarrowGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("narrow").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                                           user_op::AddOpFn AddOp) -> Maybe<void> {
diff --git a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp
index 397aea1f05b..a061187164e 100644
--- a/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp
+++ b/oneflow/user/ops/nccl_logical_2d_sbp_ops.cpp
@@ -16,200 +16,244 @@ limitations under the License.
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/user/ops/comm_net_device_infer_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all_reduce")
-    .Input("in")
-    .Output("out")
-    .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
-      *ctx->OutputIsDynamic("out", 0) =
ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(0).has_partial_sum_parallel()); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - // out dim1 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(1); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all_gather") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - // (*, S(0)) -> (*, B) - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), 0); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all_gather_noncontinuous") - .Input("in") - .Output("out") - .Attr("in_dim1_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - // 
(*, S(1)) -> (*, B) - const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); - CHECK_GE_OR_RETURN(in_split_axis, 1); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = broadcast - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_2D_same_dim0_all2all") - .Input("in") - .Output("out") - .Attr("in_dim1_split_axis", -1) - .Attr("out_dim1_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); - // (*, S(in_dim1_split_axis)) -> (*, S(out_dim1_split_axis)) - const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); - const int64_t out_split_axis = ctx->user_op_conf().attr("out_dim1_split_axis"); - CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); - CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - // in use hint - in_distribution->CopyFrom(in_dis_hint); - - // out dim0 use hint - *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); - // out dim1 = Split(out_split_axis) - out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); - - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 
0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_partial_sum_parallel()); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllReduceOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0AllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(0).has_partial_sum_parallel()); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + // out dim1 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(1); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim1AllReduceOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim1AllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + // (*, S(0)) -> (*, B) + 
CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), 0); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0AllGatherOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::GetSbp( + user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + // (*, S(1)) -> (*, B) + const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); + CHECK_GE_OR_RETURN(in_split_axis, 1); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = broadcast + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0AllGatherNoncontinuousOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0All2allOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static 
*/ Maybe _ncclLogical_2DSameDim0All2allOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel_size(), 2); + // (*, S(in_dim1_split_axis)) -> (*, S(out_dim1_split_axis)) + const int64_t in_split_axis = ctx->user_op_conf().attr("in_dim1_split_axis"); + const int64_t out_split_axis = ctx->user_op_conf().attr("out_dim1_split_axis"); + CHECK_OR_RETURN(in_dis_hint.sbp_parallel(1).has_split_parallel()); + CHECK_EQ_OR_RETURN(in_dis_hint.sbp_parallel(1).split_parallel().axis(), in_split_axis); + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_EQ_OR_RETURN(parallel_hierarchy.NumAxes(), 2); + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + // in use hint + in_distribution->CopyFrom(in_dis_hint); + + // out dim0 use hint + *out_distribution->add_sbp_parallel() = in_dis_hint.sbp_parallel(0); + // out dim1 = Split(out_split_axis) + out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); + + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogical_2DSameDim0All2allOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogical_2DSameDim0All2allOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/nccl_logical_ops.cpp b/oneflow/user/ops/nccl_logical_ops.cpp index 48915b59e06..ef0980024b9 100644 --- a/oneflow/user/ops/nccl_logical_ops.cpp +++ b/oneflow/user/ops/nccl_logical_ops.cpp @@ -16,197 +16,232 @@ limitations under the License. 
#include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" #include "oneflow/user/ops/comm_net_device_infer_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("_nccl_logical_all_reduce") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // P2B - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_reduce_scatter") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // P2S - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); - out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_all_gather") - .Input("in") - .Output("out") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - 
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // S(0)->B - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_all_gather_noncontinuous") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - CHECK_GE_OR_RETURN(in_split_axis, 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } - - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // S(1)->(B) - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_USER_OP("_nccl_logical_s2s") - .Input("in") - .Output("out") - .Attr("in_split_axis", -1) - .Attr("out_split_axis", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* 
ctx) -> Maybe { - const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); - const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); - const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); - cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); - cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); - CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); - for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { - CHECK_OR_RETURN(sbp_hint.has_split_parallel()); - CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); - } - - in_distribution->clear_sbp_parallel(); - out_distribution->clear_sbp_parallel(); - - // S(in)->S(out) - const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); - CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); - for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { - in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); - out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); - } - return Maybe::Ok(); - }) - .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe _ncclLogicalAllReduceOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllReduceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalAllReduceOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // P2B + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllReduceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalAllReduceOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = 
ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_partial_sum_parallel()); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // P2S + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_partial_sum_parallel(); + out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalReduceScatterOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalReduceScatterOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), 0); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // S(0)->B + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(0); + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalAllGatherOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferNdSbp( + user_op::InferNdSbpFnContext* ctx) { + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + 
CHECK_GE_OR_RETURN(in_split_axis, 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // S(1)->(B) + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_distribution->add_sbp_parallel()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalAllGatherNoncontinuousOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalAllGatherNoncontinuousOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} + +/* static */ Maybe _ncclLogicalS2sOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalS2sOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe _ncclLogicalS2sOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + const int64_t in_split_axis = ctx->user_op_conf().attr("in_split_axis"); + const int64_t out_split_axis = ctx->user_op_conf().attr("out_split_axis"); + const cfg::NdSbp& in_dis_hint = ctx->NdSbpHint4InputArgNameAndIndex("in", 0); + cfg::NdSbp* in_distribution = ctx->NdSbp4ArgNameAndIndex("in", 0); + cfg::NdSbp* out_distribution = ctx->NdSbp4ArgNameAndIndex("out", 0); + CHECK_GE_OR_RETURN(in_dis_hint.sbp_parallel_size(), 1); + for (const auto& sbp_hint : in_dis_hint.sbp_parallel()) { + CHECK_OR_RETURN(sbp_hint.has_split_parallel()); + CHECK_EQ_OR_RETURN(sbp_hint.split_parallel().axis(), in_split_axis); + } + + in_distribution->clear_sbp_parallel(); + out_distribution->clear_sbp_parallel(); + + // S(in)->S(out) + const Shape& parallel_hierarchy = ctx->parallel_hierarchy(); + CHECK_GE_OR_RETURN(parallel_hierarchy.NumAxes(), 1); + for (int32_t i = 0; i < parallel_hierarchy.NumAxes(); ++i) { + in_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(in_split_axis); + out_distribution->add_sbp_parallel()->mutable_split_parallel()->set_axis(out_split_axis); + } + return Maybe::Ok(); +} + +/* static */ Maybe _ncclLogicalS2sOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe> _ncclLogicalS2sOp::InferDevice( + user_op::DeviceInferContext* ctx) { + return DeviceInferFn<&SyncLaunched>(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/nd_index_slice_ops.cpp b/oneflow/user/ops/nd_index_slice_ops.cpp index 2c51b18a98d..2fa17d2d390 100644 --- a/oneflow/user/ops/nd_index_slice_ops.cpp +++ b/oneflow/user/ops/nd_index_slice_ops.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
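Annotation (sketch inferred from the call sites in this diff, not copied from
the generated header): each converted .cpp now implements a fixed set of
static hooks that oneflow/core/framework/op_generated.h declares per op,
roughly of this shape:

  struct GatherNdOp {
    static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
    static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);
    static Maybe<void> GetSbp(user_op::SbpContext* ctx);
    static Maybe<void> ModifyInputArg(const GetInputArgModifier& fn,
                                      const user_op::UserOpConfWrapper& conf);
    static Maybe<void> InferDataType(user_op::InferContext* ctx);
  };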
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -112,175 +113,207 @@ Maybe GetTensorScatterNdOptSbpSignatures(user_op::SbpContext* ctx) { } // namespace -REGISTER_USER_OP("gather_nd") - .Input("params") - .Input("indices") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& params_shape = ctx->InputShape("params", 0); - const Shape& indices_shape = ctx->InputShape("indices", 0); - int64_t index_ndims = indices_shape.At(indices_shape.NumAxes() - 1); - CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes()); - DimVector out_shape_vec(indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend() - 1); - FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { - out_shape_vec.emplace_back(params_shape.At(i)); - } - *ctx->OutputShape("out", 0) = Shape(out_shape_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& params_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("params", 0); - const user_op::TensorDesc& indices_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - int64_t indices_num_axes = indices_tensor.shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("params", 0)) - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); - FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("params", 0), i) - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("out", 0), i - index_ndims + indices_num_axes - 1) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("params", 0)) - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("params", 0); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& params_shape = ctx->InputShape("params", 0); + const Shape& indices_shape = ctx->InputShape("indices", 0); + int64_t index_ndims = indices_shape.At(indices_shape.NumAxes() - 1); + CHECK_LE_OR_RETURN(index_ndims, params_shape.NumAxes()); + DimVector out_shape_vec(indices_shape.dim_vec().cbegin(), indices_shape.dim_vec().cend() - 1); + FOR_RANGE(int64_t, i, index_ndims, params_shape.NumAxes()) { + out_shape_vec.emplace_back(params_shape.At(i)); + } + *ctx->OutputShape("out", 0) = Shape(out_shape_vec); + return Maybe::Ok(); +} -REGISTER_USER_OP("scatter_nd") - .Input("indices") - .Input("updates") - .Output("out") - .Attr("shape") - .SetTensorDescInferFn(InferScatterNdTensorDesc) - .SetDataTypeInferFn(InferScatterNdDataType) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier 
!= nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& indices_desc = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - int64_t indices_num_axes = indices_desc.shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("updates", 0), i) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - } - const Shape& out_shape = ctx->Attr("shape"); - int64_t index_ndims = indices_desc.shape().At(indices_num_axes - 1); - int64_t slice_ndims = out_shape.NumAxes() - index_ndims; - FOR_RANGE(int64_t, i, 0, slice_ndims) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) - .Split(user_op::OpArg("out", 0), i + index_ndims) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("updates", 0)) - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe GatherNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} -REGISTER_USER_OP("scatter_nd_like") - .Input("like") - .Input("indices") - .Input("updates") - .Output("out") - .SetTensorDescInferFn(InferScatterNdLikeTensorDesc) - .SetDataTypeInferFn(InferScatterNdLikeDataType) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& indices_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); - int64_t indices_num_axes = indices_tensor.shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { - ctx->NewBuilder() - .Broadcast(user_op::OpArg("like", 0)) - .Split(user_op::OpArg("indices", 0), i) - .Split(user_op::OpArg("updates", 0), i) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - } - const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); - int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); - int64_t slice_ndims = out_shape.NumAxes() - index_ndims; - FOR_RANGE(int64_t, i, 0, slice_ndims) { - ctx->NewBuilder() - .Split(user_op::OpArg("like", 0), i + index_ndims) - .Broadcast(user_op::OpArg("indices", 0)) - .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) - .Split(user_op::OpArg("out", 0), i + index_ndims) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("updates", 0)) - .Broadcast(user_op::OpArg("indices", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& params_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("params", 0); + const user_op::TensorDesc& indices_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + int64_t indices_num_axes = indices_tensor.shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("params", 0)) + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); + FOR_RANGE(int64_t, i, index_ndims, params_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("params", 0), i) + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("out", 0), i - index_ndims + indices_num_axes - 1) + 
.Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("params", 0)) + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} -REGISTER_USER_OP("tensor_scatter_nd_update") - .Input("params") - .Input("updates") - .Input("indices") - .Output("out") - .SetTensorDescInferFn(InferTensorScatterNdOptTensorDesc) - .SetDataTypeInferFn(InferTensorScatterNdOptDataType) - .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("tensor_scatter_nd_add") - .Input("params") - .Input("updates") - .Input("indices") - .Output("out") - .SetTensorDescInferFn(InferTensorScatterNdOptTensorDesc) - .SetDataTypeInferFn(InferTensorScatterNdOptDataType) - .SetGetSbpFn(GetTensorScatterNdOptSbpSignatures) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); - CHECK_OR_RETURN(indices_modifier != nullptr); - indices_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); +/* static */ Maybe GatherNdOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("params", 0); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferScatterNdTensorDesc(ctx); +} + +/*static*/ Maybe ScatterNdOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ScatterNdOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& indices_desc = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + int64_t indices_num_axes = indices_desc.shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("updates", 0), i) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + } + const Shape& out_shape = ctx->Attr("shape"); + int64_t index_ndims = indices_desc.shape().At(indices_num_axes - 1); + int64_t slice_ndims = out_shape.NumAxes() - index_ndims; + FOR_RANGE(int64_t, i, 0, slice_ndims) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) + .Split(user_op::OpArg("out", 0), i + index_ndims) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("updates", 0)) + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + 
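// Annotation (worked example with illustrative numbers): for GatherNdOp
// above, params of shape (4, 5, 6) and indices of shape (2, 3, 2) give
// index_ndims = 2 <= 3, so out = indices.shape[:-1] ++ params.shape[2:]
// = (2, 3, 6). The SBP list mirrors this: splitting indices on an axis
// i < indices_num_axes - 1 = 2 splits out on the same axis, and splitting
// params on axis 2 maps to out axis 2 - index_ndims + indices_num_axes - 1 = 2.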
indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdOp::InferDataType(user_op::InferContext* ctx) { + return InferScatterNdDataType(ctx); +} + +/* static */ Maybe ScatterNdLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferScatterNdLikeTensorDesc(ctx); +} + +/*static*/ Maybe ScatterNdLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe ScatterNdLikeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& indices_tensor = + ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0); + int64_t indices_num_axes = indices_tensor.shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, indices_num_axes - 1) { + ctx->NewBuilder() + .Broadcast(user_op::OpArg("like", 0)) + .Split(user_op::OpArg("indices", 0), i) + .Split(user_op::OpArg("updates", 0), i) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + } + const Shape& out_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape(); + int64_t index_ndims = indices_tensor.shape().At(indices_num_axes - 1); + int64_t slice_ndims = out_shape.NumAxes() - index_ndims; + FOR_RANGE(int64_t, i, 0, slice_ndims) { + ctx->NewBuilder() + .Split(user_op::OpArg("like", 0), i + index_ndims) + .Broadcast(user_op::OpArg("indices", 0)) + .Split(user_op::OpArg("updates", 0), i + indices_num_axes - 1) + .Split(user_op::OpArg("out", 0), i + index_ndims) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("updates", 0)) + .Broadcast(user_op::OpArg("indices", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe ScatterNdLikeOp::InferDataType(user_op::InferContext* ctx) { + return InferScatterNdLikeDataType(ctx); +} + +/* static */ Maybe TensorScatterNdUpdateOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return InferTensorScatterNdOptTensorDesc(ctx); +} + +/*static*/ Maybe TensorScatterNdUpdateOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe TensorScatterNdUpdateOp::GetSbp(user_op::SbpContext* ctx) { + return GetTensorScatterNdOptSbpSignatures(ctx); +} + +/* static */ Maybe TensorScatterNdUpdateOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe TensorScatterNdUpdateOp::InferDataType(user_op::InferContext* ctx) { + return InferTensorScatterNdOptDataType(ctx); +} + +/* static */ Maybe TensorScatterNdAddOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorScatterNdOptTensorDesc(ctx); +} + +/*static*/ Maybe TensorScatterNdAddOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe TensorScatterNdAddOp::GetSbp(user_op::SbpContext* ctx) { + return GetTensorScatterNdOptSbpSignatures(ctx); +} + +/* static */ Maybe TensorScatterNdAddOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0); + CHECK_OR_RETURN(indices_modifier != nullptr); + indices_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe 
TensorScatterNdAddOp::InferDataType(user_op::InferContext* ctx) { + return InferTensorScatterNdOptDataType(ctx); +} REGISTER_USER_OP_GRAD("gather_nd") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/nll_op.cpp b/oneflow/user/ops/nll_op.cpp index 0d71b084651..b170194aff4 100644 --- a/oneflow/user/ops/nll_op.cpp +++ b/oneflow/user/ops/nll_op.cpp @@ -15,6 +15,7 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/loss_op_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -44,7 +45,7 @@ Maybe InferTensorDescFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe InferDataType(user_op::InferContext* ctx) { +Maybe NllInferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0); CHECK_OR_RETURN(IsIndexDataType(target_desc.data_type())); @@ -88,41 +89,51 @@ Maybe InferGradDataType(user_op::InferContext* ctx) { } } // namespace -REGISTER_USER_OP("nll") - .Input("input") - .Input("target") - .OptionalInput("weight") - .Output("out") - .Output("total_weight") - .Attr("ignore_index") - .SetTensorDescInferFn(InferTensorDescFn) - .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); - CHECK_OR_RETURN(target_modifier != nullptr); - target_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn(InferDataType) - .SetGetSbpFn(GenLossForwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder, - user_op::SbpContext* ctx) { - builder.PartialSum(user_op::OpArg("total_weight", 0)); - })); - -REGISTER_USER_OP("nll_grad") - .Input("input") - .Input("target") - .Input("total_weight") - .OptionalInput("weight") - .Input("dy") - .Output("dx") - .Attr("ignore_index") - .SetTensorDescInferFn(InferGradTensorDescFn) - .SetDataTypeInferFn(InferGradDataType) - .SetGetSbpFn(GenLossBackwardDefaultGetSbpFn([](user_op::UserOpSbpSignatureBuilder& builder, - user_op::SbpContext* ctx) { - builder.PartialSum(user_op::OpArg("total_weight", 0)); - })); +/* static */ Maybe NllOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferTensorDescFn(ctx); +} + +/*static*/ Maybe NllOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NllOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossForwardDefaultGetSbpFn( + [](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + builder.PartialSum(user_op::OpArg("total_weight", 0)); + })(ctx); +} + +/* static */ Maybe NllOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0); + CHECK_OR_RETURN(target_modifier != nullptr); + target_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe NllOp::InferDataType(user_op::InferContext* ctx) { + return NllInferDataType(ctx); +} + +/* static */ Maybe NllGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferGradTensorDescFn(ctx); +} + +/*static*/ Maybe NllGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NllGradOp::GetSbp(user_op::SbpContext* ctx) { + return GenLossBackwardDefaultGetSbpFn( + 
[](user_op::UserOpSbpSignatureBuilder& builder, user_op::SbpContext* ctx) { + builder.PartialSum(user_op::OpArg("total_weight", 0)); + })(ctx); +} + +/* static */ Maybe<void> NllGradOp::InferDataType(user_op::InferContext* ctx) { + return InferGradDataType(ctx); +} REGISTER_USER_OP_GRAD("nll").SetGenBackwardOpConfFn( [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe<void> { diff --git a/oneflow/user/ops/nms_op.cpp b/oneflow/user/ops/nms_op.cpp index 20dcbbf5901..1d9c0e29537 100644 --- a/oneflow/user/ops/nms_op.cpp +++ b/oneflow/user/ops/nms_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -31,13 +32,20 @@ Maybe<void> InferNmsDataType(user_op::InferContext* ctx) { } // namespace -REGISTER_USER_OP("nms") - .Input("in") - .Output("out") - .Attr<float>("iou_threshold") - .Attr<int32_t>("keep_n") - .SetTensorDescInferFn(InferNmsTensorDesc) - .SetDataTypeInferFn(InferNmsDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/* static */ Maybe<void> NmsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferNmsTensorDesc(ctx); +} + +/*static*/ Maybe<void> NmsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe<void> NmsOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} + +/* static */ Maybe<void> NmsOp::InferDataType(user_op::InferContext* ctx) { + return InferNmsDataType(ctx); +} } // namespace oneflow diff --git a/oneflow/user/ops/normalization_op.cpp b/oneflow/user/ops/normalization_op.cpp index df95d677b95..2444ca274a4 100644 --- a/oneflow/user/ops/normalization_op.cpp +++ b/oneflow/user/ops/normalization_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
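Annotation (pattern summary for the file below): MakeFwTensorDescInferFn()
and MakeFwDataTypeInferFn() are factories returning inference functors that
optionally close over a reserve_space hook; each static method builds the
functor and invokes it immediately with the context, e.g.

  /* static */ Maybe<void> NormalizationOp::InferLogicalTensorDesc(
      user_op::InferContext* ctx) {
    return MakeFwTensorDescInferFn()(ctx);
  }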
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" #ifdef WITH_CUDA #include "oneflow/core/device/cuda_util.h" #include "oneflow/core/device/cudnn_util.h" @@ -212,77 +213,82 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() { user_op::TensorDesc* reserve_space)>()); } -REGISTER_USER_OP("normalization") - .Input("x") - .OptionalInput("moving_mean") - .OptionalInput("moving_variance") - .Input("gamma") - .Input("beta") - .OptionalInput("_add_to_output") - .Output("y") - .OptionalOutput("mean") - .OptionalOutput("inv_variance") - .Attr("axis") - .Attr("epsilon") - .Attr("training") - .Attr("momentum") - .SetInputArgModifyFn(FwInputArgModifyFn) - .SetTensorDescInferFn(MakeFwTensorDescInferFn()) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn(MakeFwDataTypeInferFn()); - -REGISTER_USER_OP("normalization_add_relu") - .Input("x") - .OptionalInput("addend") - .OptionalInput("moving_mean") - .OptionalInput("moving_variance") - .Input("gamma") - .Input("beta") - .Output("y") - .Output("reserve_space") - .OptionalOutput("mean") - .OptionalOutput("inv_variance") - .Attr("axis") - .Attr("epsilon") - .Attr("training") - .Attr("momentum") - .SetInputArgModifyFn(FwInputArgModifyFn) - .SetLogicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const auto& x_desc = ctx->InputTensorDesc("x", 0); - size_t reserve_space_bits = x_desc.shape().elem_cnt(); - int64_t parallel_num = ctx->parallel_num(); - if (parallel_num != 1) { - // There no need to call SbpParallel4ArgNameAndIndex when parallel_num = 1 in local. - const cfg::SbpParallel& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); - if (x_sbp.has_split_parallel()) { - CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); - reserve_space_bits = reserve_space_bits / ctx->parallel_num(); - } - } - *reserve_space->mut_shape() = - Shape({static_cast(RoundUp(reserve_space_bits, 32) / 32)}); - return Maybe::Ok(); - })) - .SetPhysicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const auto& x_desc = ctx->InputTensorDesc("x", 0); - *reserve_space->mut_shape() = - Shape({static_cast(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}); - return Maybe::Ok(); - })) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn( - MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - *reserve_space->mut_data_type() = DataType::kInt32; - return Maybe::Ok(); - })); +} // namespace + +/* static */ Maybe NormalizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn()(ctx); +} + +/*static*/ Maybe NormalizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalizationOp::GetSbp(user_op::SbpContext* ctx) { + return FwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return FwInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe NormalizationOp::InferDataType(user_op::InferContext* ctx) { + return MakeFwDataTypeInferFn()(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const 
user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const auto& x_desc = ctx->InputTensorDesc("x", 0); + size_t reserve_space_bits = x_desc.shape().elem_cnt(); + int64_t parallel_num = ctx->parallel_num(); + if (parallel_num != 1) { + // There no need to call SbpParallel4ArgNameAndIndex when parallel_num = 1 in local. + const cfg::SbpParallel& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); + if (x_sbp.has_split_parallel()) { + CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); + reserve_space_bits = reserve_space_bits / ctx->parallel_num(); + } + } + *reserve_space->mut_shape() = + Shape({static_cast(RoundUp(reserve_space_bits, 32) / 32)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const auto& x_desc = ctx->InputTensorDesc("x", 0); + *reserve_space->mut_shape() = + Shape({static_cast(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { + return FwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationAddReluOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return FwInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) { + return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + *reserve_space->mut_data_type() = DataType::kInt32; + return Maybe::Ok(); + })(ctx); +} #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) +namespace { + void InferCudnnReserveSpaceSize(DataType data_type, cudnnBatchNormOps_t ops, int64_t n, int64_t c, int64_t h, int64_t w, size_t* reserve_space_size) { cudnnHandle_t cudnn_handle; @@ -295,79 +301,110 @@ void InferCudnnReserveSpaceSize(DataType data_type, cudnnBatchNormOps_t ops, int OF_CUDNN_CHECK(cudnnDestroy(cudnn_handle)); } -REGISTER_USER_OP("cudnn_fused_normalization_add_relu") - .Input("x") - .OptionalInput("addend") - .OptionalInput("moving_mean") - .OptionalInput("moving_variance") - .Input("gamma") - .Input("beta") - .Output("y") - .Output("reserve_space") - .OptionalOutput("mean") - .OptionalOutput("inv_variance") - .Attr("axis") - .Attr("epsilon") - .Attr("momentum") - .SetInputArgModifyFn(FwInputArgModifyFn) - .SetLogicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const Shape& x_shape = x->shape(); - const auto axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); - int64_t n = x_shape.At(0); - int64_t h = x_shape.Count(1, axis); - int64_t w = 1; - int64_t c = x_shape.At(axis); - const auto& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); - if (x_sbp.has_split_parallel()) { - CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); - n = n / ctx->parallel_num(); - } - cudnnBatchNormOps_t ops; - if (ctx->has_input("addend", 0)) { - ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; - } else { - ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; - } - size_t reserve_space_size; - InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); - reserve_space_size = std::max(reserve_space_size, GetOneVal()); - 
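// Annotation (worked example, illustrative numbers): for the
// normalization_add_relu reserve_space sizing above, x of shape
// (128, 64, 112, 112) has 102760448 elements; packing one mask bit per
// element into 32-bit words gives RoundUp(102760448, 32) / 32 = 3211264
// entries. In the logical variant with x split S(0) over parallel_num = 4,
// the bit count is divided by 4 first, so each rank reserves a quarter.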
*reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); - return Maybe::Ok(); - })) - .SetPhysicalTensorDescInferFn( - MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - const Shape& x_shape = x->shape(); - const auto axis = ctx->Attr("axis"); - CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); - int64_t n = x_shape.At(0); - int64_t h = x_shape.Count(1, axis); - int64_t w = 1; - int64_t c = x_shape.At(axis); - cudnnBatchNormOps_t ops; - if (ctx->has_input("addend", 0)) { - ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; - } else { - ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; - } - size_t reserve_space_size; - InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); - reserve_space_size = std::max(reserve_space_size, GetOneVal()); - *reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); - return Maybe::Ok(); - })) - .SetGetSbpFn(FwGetSbpFn) - .SetDataTypeInferFn( - MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, - user_op::TensorDesc* reserve_space) -> Maybe { - *reserve_space->mut_data_type() = DataType::kChar; - return Maybe::Ok(); - })); +} // namespace -#endif +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const Shape& x_shape = x->shape(); + const auto axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); + int64_t n = x_shape.At(0); + int64_t h = x_shape.Count(1, axis); + int64_t w = 1; + int64_t c = x_shape.At(axis); + const auto& x_sbp = ctx->SbpParallel4ArgNameAndIndex("x", 0); + if (x_sbp.has_split_parallel()) { + CHECK_EQ_OR_RETURN(x_sbp.split_parallel().axis(), 0); + n = n / ctx->parallel_num(); + } + cudnnBatchNormOps_t ops; + if (ctx->has_input("addend", 0)) { + ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; + } else { + ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; + } + size_t reserve_space_size; + InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); + reserve_space_size = std::max(reserve_space_size, GetOneVal()); + *reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + const Shape& x_shape = x->shape(); + const auto axis = ctx->Attr("axis"); + CHECK_EQ_OR_RETURN(x_shape.Count(axis + 1), 1); + int64_t n = x_shape.At(0); + int64_t h = x_shape.Count(1, axis); + int64_t w = 1; + int64_t c = x_shape.At(axis); + cudnnBatchNormOps_t ops; + if (ctx->has_input("addend", 0)) { + ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION; + } else { + ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION; + } + size_t reserve_space_size; + InferCudnnReserveSpaceSize(x->data_type(), ops, n, c, h, w, &reserve_space_size); + reserve_space_size = std::max(reserve_space_size, GetOneVal()); + *reserve_space->mut_shape() = Shape({static_cast(reserve_space_size)}); + return Maybe::Ok(); + })(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { + return FwGetSbpFn(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::ModifyInputArg( + const 
GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return FwInputArgModifyFn(GetInputArgModifierFn, conf); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferDataType( + user_op::InferContext* ctx) { + return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x, + user_op::TensorDesc* reserve_space) -> Maybe { + *reserve_space->mut_data_type() = DataType::kChar; + return Maybe::Ok(); + })(ctx); +} + +#else + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluOp::InferDataType( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +#endif // WITH_CUDA + +namespace { Maybe BwTensorDescInferFn(user_op::InferContext* ctx) { #ifdef WITH_CUDA @@ -447,60 +484,83 @@ Maybe BwGetSbpFn(user_op::SbpContext* ctx) { return Maybe::Ok(); } -REGISTER_USER_OP("normalization_grad") - .Input("x") - .Input("dy") - .Input("mean") - .Input("inv_variance") - .Input("gamma") - .Output("gamma_diff") - .Output("beta_diff") - .Output("dx") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn(BwTensorDescInferFn) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(BwDataTypeInferFn); - -REGISTER_USER_OP("normalization_add_relu_grad") - .Input("x") - .Input("dy") - .Input("mean") - .Input("inv_variance") - .Input("gamma") - .Input("beta") - .Input("reserve_space") - .Input("y") - .Output("gamma_diff") - .Output("beta_diff") - .Output("dx") - .OptionalOutput("addend_diff") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn(BwTensorDescInferFn) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(BwDataTypeInferFn); +} // namespace + +/* static */ Maybe NormalizationGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return BwTensorDescInferFn(ctx); +} + +/*static*/ Maybe NormalizationGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalizationGradOp::GetSbp(user_op::SbpContext* ctx) { + return BwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationGradOp::InferDataType(user_op::InferContext* ctx) { + return BwDataTypeInferFn(ctx); +} + +/* static */ Maybe NormalizationAddReluGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return BwTensorDescInferFn(ctx); +} + +/*static*/ Maybe NormalizationAddReluGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { + return BwGetSbpFn(ctx); +} + +/* static */ Maybe NormalizationAddReluGradOp::InferDataType(user_op::InferContext* ctx) { + return BwDataTypeInferFn(ctx); +} #if defined(WITH_CUDA) && (CUDNN_VERSION >= 7401) 
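// Annotation: as with the forward ops earlier in this file, the grad op is
// defined on both sides of this guard, so every hook declared by the
// generated header has a definition even in CPU-only builds. Stub shape,
// shown with a hypothetical op name:
//
//   /* static */ Maybe<void> SomeCudnnOnlyGradOp::GetSbp(user_op::SbpContext* ctx) {
//     return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401";
//   }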
-REGISTER_USER_OP("cudnn_fused_normalization_add_relu_grad") - .Input("x") - .Input("dy") - .Input("mean") - .Input("inv_variance") - .Input("gamma") - .Input("beta") - .Input("reserve_space") - .Input("y") - .Output("gamma_diff") - .Output("beta_diff") - .Output("dx") - .OptionalOutput("addend_diff") - .Attr("axis") - .Attr("epsilon") - .SetTensorDescInferFn(BwTensorDescInferFn) - .SetGetSbpFn(BwGetSbpFn) - .SetDataTypeInferFn(BwDataTypeInferFn); +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return BwTensorDescInferFn(ctx); +} + +/*static*/ Maybe CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { + return BwGetSbpFn(ctx); +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferDataType( + user_op::InferContext* ctx) { + return BwDataTypeInferFn(ctx); +} + +#else + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/*static*/ Maybe CudnnFusedNormalizationAddReluGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} + +/* static */ Maybe CudnnFusedNormalizationAddReluGradOp::InferDataType( + user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA and CuDNN >= 7401"; +} #endif @@ -709,6 +769,4 @@ REGISTER_USER_OP_GRAD("normalization_add_relu") return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/nvtx_range_op.cpp b/oneflow/user/ops/nvtx_range_op.cpp index e0a222a7bd5..0f2bd54b2e6 100644 --- a/oneflow/user/ops/nvtx_range_op.cpp +++ b/oneflow/user/ops/nvtx_range_op.cpp @@ -13,67 +13,103 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ + #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +#ifdef WITH_CUDA -REGISTER_USER_OP("nvtx_start") - .Input("in") - .Output("out") - .Attr("mark_prefix") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("nvtx_end") - .Input("in") - .Output("out") - .Attr("mark_prefix") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NvtxStartOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +/* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} + +/*static*/ Maybe NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NvtxEndOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + 
ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} + +/* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} + +#else + +/* static */ Maybe NvtxStartOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/*static*/ Maybe NvtxStartOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe NvtxStartOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxStartOp::InferDataType(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxEndOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/*static*/ Maybe NvtxEndOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxEndOp::GetSbp(user_op::SbpContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +/* static */ Maybe NvtxEndOp::InferDataType(user_op::InferContext* ctx) { + return Error::UnimplementedError() << "require CUDA to use NVTX"; +} + +#endif // WITH_CUDA REGISTER_USER_OP_GRAD("nvtx_start") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -108,6 +144,5 @@ REGISTER_USER_OP_GRAD("nvtx_end") } return Maybe::Ok(); }); -} // namespace } // namespace oneflow diff --git a/oneflow/user/ops/ofrecord_decoder_ops.cpp b/oneflow/user/ops/ofrecord_decoder_ops.cpp index ff19fe78952..02ccf542062 100644 --- a/oneflow/user/ops/ofrecord_decoder_ops.cpp +++ b/oneflow/user/ops/ofrecord_decoder_ops.cpp @@ -15,148 +15,148 @@ limitations under the License. 
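Annotation (worked example, illustrative numbers): for the raw decoder below,
in of shape (64,), holding one OFRecord per row, with attr shape = (28, 28)
yields dim_vec = {64, 28, 28}, so out is (64, 28, 28); the output dtype is
copied verbatim from the "data_type" attribute.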
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_raw_decoder") - .Input("in") - .Output("out") - .Attr("name") - .Attr("shape") - .Attr("data_type") - .Attr("dim1_varying_length", false) - .Attr("truncate", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - Shape conf_shape = ctx->Attr("shape"); - DimVector dim_vec(1 + conf_shape.NumAxes()); - dim_vec[0] = in_tensor.shape().At(0); - for (int i = 1; i < dim_vec.size(); ++i) { dim_vec[i] = conf_shape.At(i - 1); } - *out_tensor->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); - *out_tensor->mut_data_type() = ctx->Attr("data_type"); - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_bytes_decoder") - .Input("in") - .Output("out") - .Attr("name") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_is_dynamic() = in.is_dynamic(); - *out->mut_shape() = in.shape(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord); - *out->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder") - .Input("in") - .Output("out") - .Attr("name") - .Attr("color_space", "BGR") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - *out_tensor->mut_shape() = in_tensor.shape(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - 
user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); - *out_tensor->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_decoder_random_crop") - .Input("in") - .Output("out") - .Attr("name") - .Attr("color_space", "BGR") - .Attr("num_attempts", 10) - .Attr("seed", -1) - .Attr("has_seed", false) - .Attr>("random_area", {0.08, 1.0}) - .Attr>("random_aspect_ratio", {0.75, 1.333333}) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); - *out_tensor->mut_shape() = in_tensor.shape(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); - CHECK_NOTNULL_OR_RETURN(in_modifier); - in_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); - *out_tensor->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }); +/* static */ Maybe OfrecordRawDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + Shape conf_shape = ctx->Attr("shape"); + DimVector dim_vec(1 + conf_shape.NumAxes()); + dim_vec[0] = in_tensor.shape().At(0); + for (int i = 1; i < dim_vec.size(); ++i) { dim_vec[i] = conf_shape.At(i - 1); } + *out_tensor->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} + +/*static*/ Maybe OfrecordRawDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordRawDecoderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordRawDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ 
Maybe OfrecordRawDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); + *out_tensor->mut_data_type() = ctx->Attr("data_type"); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordBytesDecoderOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_is_dynamic() = in.is_dynamic(); + *out->mut_shape() = in.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe OfrecordBytesDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordBytesDecoderOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::SplitForEachAxis(ctx); +} + +/* static */ Maybe OfrecordBytesDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordBytesDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in.data_type() == DataType::kOFRecord); + *out->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + *out_tensor->mut_shape() = in_tensor.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe OfrecordImageDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordImageDecoderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); + *out_tensor->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1); + *out_tensor->mut_shape() = in_tensor.shape(); + return Maybe::Ok(); +} + +/*static*/ Maybe 
OfrecordImageDecoderRandomCropOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build(); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0); + CHECK_NOTNULL_OR_RETURN(in_modifier); + in_modifier->set_requires_grad(false); + return Maybe::Ok(); +} + +/* static */ Maybe OfrecordImageDecoderRandomCropOp::InferDataType( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(in_tensor.data_type() == DataType::kOFRecord); + *out_tensor->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp index 2e73940718f..5e1c21cc54c 100644 --- a/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp +++ b/oneflow/user/ops/ofrecord_image_classification_reader_op.cpp @@ -14,66 +14,57 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("ofrecord_image_classification_reader") - .Output("image") - .Output("label") - .Attr("data_dir") - .Attr("data_part_num") - .Attr("batch_size") - .Attr("part_name_prefix", "part-") - .Attr("part_name_suffix_length", -1) - .Attr("random_shuffle", false) - .Attr("seed", -1) - .Attr("shuffle_buffer_size", 1024) - .Attr("shuffle_after_epoch", false) - .Attr("color_space", "BGR") - .Attr("image_feature_name", "encoded") - .Attr("label_feature_name", "class/label") - .Attr("decode_buffer_size_per_thread", 8) - .Attr("num_decode_threads_per_machine", 0) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); - user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); - int32_t local_batch_size = ctx->Attr("batch_size"); - const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0); - int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - if (sbp.has_split_parallel() && parallel_num > 1) { - CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0); - local_batch_size /= parallel_num; - } - *image_tensor->mut_shape() = Shape({local_batch_size}); - *label_tensor->mut_shape() = Shape({local_batch_size}); - return Maybe::Ok(); - }) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0); - user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0); - int32_t batch_size = ctx->Attr("batch_size"); - *image_tensor->mut_shape() = Shape({batch_size}); - *label_tensor->mut_shape() = Shape({batch_size}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier 
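Annotation (convention used by the readers below, illustrative numbers):
InferLogicalTensorDesc reports the global batch while InferPhysicalTensorDesc
reports the per-rank slice. With batch_size = 256 and the output placed S(0)
over parallel_num = 8, the logical shape is (256) and the physical shape is
(32); CHECK_EQ_OR_RETURN(batch_size % parallel_num, 0) guards the division.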
GetOutputArgModifierFn,
-                             const user_op::UserOpConfWrapper& conf) -> Maybe<void> {
-      user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0);
-      CHECK_OR_RETURN(image_modifier != nullptr);
-      image_modifier->set_header_infered_before_compute(false);
-      user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn("label", 0);
-      CHECK_OR_RETURN(label_modifier != nullptr);
-      label_modifier->set_header_infered_before_compute(false);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("image", 0) = DataType::kTensorBuffer;
-      *ctx->OutputDType("label", 0) = DataType::kTensorBuffer;
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> OfrecordImageClassificationReaderOp::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0);
+  user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0);
+  int32_t batch_size = ctx->Attr<int32_t>("batch_size");
+  *image_tensor->mut_shape() = Shape({batch_size});
+  *label_tensor->mut_shape() = Shape({batch_size});
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OfrecordImageClassificationReaderOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  user_op::TensorDesc* image_tensor = ctx->OutputTensorDesc("image", 0);
+  user_op::TensorDesc* label_tensor = ctx->OutputTensorDesc("label", 0);
+  int32_t local_batch_size = ctx->Attr<int32_t>("batch_size");
+  const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("image", 0);
+  int64_t parallel_num = ctx->parallel_ctx().parallel_num();
+  if (sbp.has_split_parallel() && parallel_num > 1) {
+    CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0);
+    local_batch_size /= parallel_num;
+  }
+  *image_tensor->mut_shape() = Shape({local_batch_size});
+  *label_tensor->mut_shape() = Shape({local_batch_size});
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OfrecordImageClassificationReaderOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(ctx->outputs(), 0).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OfrecordImageClassificationReaderOp::ModifyOutputArg(
+    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::OutputArgModifier* image_modifier = GetOutputArgModifierFn("image", 0);
+  CHECK_OR_RETURN(image_modifier != nullptr);
+  image_modifier->set_header_infered_before_compute(false);
+  user_op::OutputArgModifier* label_modifier = GetOutputArgModifierFn("label", 0);
+  CHECK_OR_RETURN(label_modifier != nullptr);
+  label_modifier->set_header_infered_before_compute(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OfrecordImageClassificationReaderOp::InferDataType(
+    user_op::InferContext* ctx) {
+  *ctx->OutputDType("image", 0) = DataType::kTensorBuffer;
+  *ctx->OutputDType("label", 0) = DataType::kTensorBuffer;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/ofrecord_reader_op.cpp b/oneflow/user/ops/ofrecord_reader_op.cpp
index 5e29638c773..475d058f00b 100644
--- a/oneflow/user/ops/ofrecord_reader_op.cpp
+++ b/oneflow/user/ops/ofrecord_reader_op.cpp
@@ -14,59 +14,53 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
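The reader conversions that follow keep a deliberate split between logical and physical shape inference: the logical shape is the user-configured batch size, while each rank's physical shape divides that batch by the parallel degree whenever the output SBP is split. A minimal standalone sketch of the arithmetic, mirroring OFRecordReaderOp::InferPhysicalTensorDesc below (the helper name and the plain-integer interface are illustrative, not part of the patch):

#include <cassert>
#include <cstdint>

// Per-rank batch size under a split SBP: a batch_size of 64 on 4 ranks yields 16.
int32_t LocalBatchSize(int32_t batch_size, int64_t parallel_num, bool has_split_parallel) {
  if (has_split_parallel && parallel_num > 1) {
    // The op uses CHECK_EQ_OR_RETURN here; the batch must divide evenly across ranks.
    assert(batch_size % parallel_num == 0);
    return static_cast<int32_t>(batch_size / parallel_num);
  }
  return batch_size;
}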
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_CPU_ONLY_USER_OP("OFRecordReader")
-    .Output("out")
-    .Attr<std::string>("data_dir")
-    .Attr<int32_t>("data_part_num")
-    .Attr<int32_t>("batch_size")
-    .Attr<std::string>("part_name_prefix", "part-")
-    .Attr<int32_t>("part_name_suffix_length", -1)
-    .Attr<bool>("random_shuffle", false)
-    .Attr<int64_t>("seed", -1)
-    .Attr<int32_t>("shuffle_buffer_size", 1024)
-    .Attr<bool>("shuffle_after_epoch", false)
-    .Attr<std::vector<std::string>>("nd_sbp")
-    .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
-      int32_t batch_size = ctx->Attr<int32_t>("batch_size");
-      const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0);
-      int64_t parallel_num = ctx->parallel_ctx().parallel_num();
-      if (sbp.has_split_parallel() && parallel_num > 1) {
-        CHECK_EQ_OR_RETURN(batch_size % parallel_num, 0);
-        batch_size /= parallel_num;
-      }
-      *out_tensor->mut_shape() = Shape({batch_size});
-      return Maybe<void>::Ok();
-    })
-    .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
-      *out_tensor->mut_shape() = Shape({ctx->Attr<int32_t>("batch_size")});
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = DataType::kOFRecord;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Split(ctx->outputs(), 0).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> {
-      cfg::SbpParallel default_sbp;
-      default_sbp.mutable_split_parallel()->set_axis(0);
-      return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
-    })
-    .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn,
-                             const user_op::UserOpConfWrapper& conf) -> Maybe<void> {
-      user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0);
-      CHECK_OR_RETURN(out_modifier != nullptr);
-      // NOTE(chengcheng): OFRecordReader Only support static shape infer which will read all batch
-      //   size data with output shape (batch_size,)
-      // out_modifier->set_header_infered_before_compute(false);
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> OFRecordReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  *out_tensor->mut_shape() = Shape({ctx->Attr<int32_t>("batch_size")});
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OFRecordReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  int32_t batch_size = ctx->Attr<int32_t>("batch_size");
+  const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0);
+  int64_t parallel_num = ctx->parallel_ctx().parallel_num();
+  if (sbp.has_split_parallel() && parallel_num > 1) {
+    CHECK_EQ_OR_RETURN(batch_size % parallel_num, 0);
+    batch_size /= parallel_num;
+  }
+  *out_tensor->mut_shape() = Shape({batch_size});
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OFRecordReaderOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(ctx->outputs(), 0).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OFRecordReaderOp::ModifyOutputArg(
+    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0);
+  CHECK_OR_RETURN(out_modifier != nullptr);
+  // NOTE(chengcheng): OFRecordReader only supports static shape inference, which reads a
+  //   full batch with output shape (batch_size,)
+  // out_modifier->set_header_infered_before_compute(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OFRecordReaderOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) {
+  cfg::SbpParallel default_sbp;
+  default_sbp.mutable_split_parallel()->set_axis(0);
+  return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
+}
+
+/* static */ Maybe<void> OFRecordReaderOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = DataType::kOFRecord;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/one_hot_op.cpp b/oneflow/user/ops/one_hot_op.cpp
index c28be85e05f..ded1daf6e20 100644
--- a/oneflow/user/ops/one_hot_op.cpp
+++ b/oneflow/user/ops/one_hot_op.cpp
@@ -15,56 +15,55 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/common/balanced_splitter.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("one_hot")
-    .Input("indices")
-    .Output("out")
-    .Attr<int64_t>("depth")
-    .Attr<double>("floating_on_value")
-    .Attr<int64_t>("integer_on_value")
-    .Attr<double>("floating_off_value")
-    .Attr<int64_t>("integer_off_value")
-    .Attr<DataType>("dtype")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const int64_t depth = ctx->Attr<int64_t>("depth");
-      CHECK_GT_OR_RETURN(depth, 0);
-      const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0);
-      CHECK_GT_OR_RETURN(indices_desc.shape().NumAxes(), 0);
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      *out_desc->mut_is_dynamic() = indices_desc.is_dynamic();
-      DimVector dim_vec = indices_desc.shape().dim_vec();
-      dim_vec.emplace_back(depth);
-      *out_desc->mut_shape() = Shape(dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0);
-      CHECK_OR_RETURN(indices_modifier != nullptr);
-      indices_modifier->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& indices_tensor =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0);
-      FOR_RANGE(int64_t, i, 0, indices_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("indices", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
+/* static */ Maybe<void> OneHotOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const int64_t depth = ctx->Attr<int64_t>("depth");
+  CHECK_GT_OR_RETURN(depth, 0);
+  const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0);
+  CHECK_GT_OR_RETURN(indices_desc.shape().NumAxes(), 0);
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  *out_desc->mut_is_dynamic() = indices_desc.is_dynamic();
+  DimVector dim_vec = indices_desc.shape().dim_vec();
+  dim_vec.emplace_back(depth);
+  *out_desc->mut_shape() = Shape(dim_vec);
+  return Maybe<void>::Ok();
+}
 
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0);
-      CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type()));
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      DataType dtype = ctx->Attr<DataType>("dtype");
-      *out_desc->mut_data_type() = dtype;
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> OneHotOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static
 */ Maybe<void> OneHotOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& indices_tensor =
+      ctx->LogicalTensorDesc4InputArgNameAndIndex("indices", 0);
+  FOR_RANGE(int64_t, i, 0, indices_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("indices", 0), i)
+        .Split(user_op::OpArg("out", 0), i)
+        .Build();
+  }
+
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OneHotOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,
+                                                  const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* indices_modifier = GetInputArgModifierFn("indices", 0);
+  CHECK_OR_RETURN(indices_modifier != nullptr);
+  indices_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OneHotOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& indices_desc = ctx->InputTensorDesc("indices", 0);
+  CHECK_OR_RETURN(IsIndexDataType(indices_desc.data_type()));
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  DataType dtype = ctx->Attr<DataType>("dtype");
+  *out_desc->mut_data_type() = dtype;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/onerec_decoder_op.cpp b/oneflow/user/ops/onerec_decoder_op.cpp
index ede82961920..6057dd6486c 100644
--- a/oneflow/user/ops/onerec_decoder_op.cpp
+++ b/oneflow/user/ops/onerec_decoder_op.cpp
@@ -14,65 +14,61 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_CPU_ONLY_USER_OP("onerec_decoder")
-    .Input("in")
-    .Output("out")
-    .Attr<std::string>("key")
-    .Attr<DataType>("data_type")
-    .Attr<Shape>("static_shape")
-    .Attr<bool>("is_dynamic", false)
-    .Attr<bool>("has_reshape", false)
-    .Attr<Shape>("reshape")
-    .Attr<bool>("has_batch_padding", false)
-    .Attr<Shape>("batch_padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
-      user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
-      CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1);
-      const Shape& static_shape = ctx->Attr<Shape>("static_shape");
-      DimVector dim_vec(1 + static_shape.NumAxes());
-      dim_vec[0] = in_tensor.shape().At(0);
-      FOR_RANGE(int64_t, i, 1, dim_vec.size()) { dim_vec[i] = static_shape.At(i - 1); }
-      *out_tensor->mut_shape() = Shape(dim_vec);
-      out_tensor->set_is_dynamic(ctx->Attr<bool>("is_dynamic"));
-      return Maybe<void>::Ok();
-    })
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0);
-      CHECK_NOTNULL_OR_RETURN(in_modifier);
-      in_modifier->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("in", 0), 0)
-          .Split(user_op::OpArg("out", 0), 0)
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn,
-                             const user_op::UserOpConfWrapper& conf) -> Maybe<void> {
-      // NOTE(yaochi): refer to tensor_buffer_to_list_of_tensors
-      // In order to support consistent tensor, set set_header_infered_before_compute to false
-      // only when is_dynamic == true
-      if (conf.attr<bool>("is_dynamic")) {
-        FOR_RANGE(int64_t, i, 0, conf.output_size("out")) {
-          user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i);
-          CHECK_OR_RETURN(out_i_modifier != nullptr);
-          out_i_modifier->set_header_infered_before_compute(false);
-        }
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
-      user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
-      CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer);
-      *out_tensor->mut_data_type() = ctx->Attr<DataType>("data_type");
-      return Maybe<void>::Ok();
-    });
+/* static */ Maybe<void> OnerecDecoderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  CHECK_OR_RETURN(in_tensor.shape().NumAxes() == 1 && in_tensor.shape().At(0) >= 1);
+  const Shape& static_shape = ctx->Attr<Shape>("static_shape");
+  DimVector dim_vec(1 + static_shape.NumAxes());
+  dim_vec[0] = in_tensor.shape().At(0);
+  FOR_RANGE(int64_t, i, 1, dim_vec.size()) { dim_vec[i] = static_shape.At(i - 1); }
+  *out_tensor->mut_shape() = Shape(dim_vec);
+  out_tensor->set_is_dynamic(ctx->Attr<bool>("is_dynamic"));
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> OnerecDecoderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+
+/* static */ Maybe<void> OnerecDecoderOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build();
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OnerecDecoderOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  user_op::InputArgModifier* in_modifier = GetInputArgModifierFn("in", 0);
+  CHECK_NOTNULL_OR_RETURN(in_modifier);
+  in_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OnerecDecoderOp::ModifyOutputArg(
+    const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
+  // NOTE(yaochi): refer to tensor_buffer_to_list_of_tensors
+  // In order to support consistent tensor, set set_header_infered_before_compute to false
+  // only when is_dynamic == true
+  if (conf.attr<bool>("is_dynamic")) {
+    FOR_RANGE(int64_t, i, 0, conf.output_size("out")) {
+      user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i);
+      CHECK_OR_RETURN(out_i_modifier != nullptr);
+      out_i_modifier->set_header_infered_before_compute(false);
+    }
+  }
+  return Maybe<void>::Ok();
+}
+
+/* static */ Maybe<void> OnerecDecoderOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->InputTensorDesc("in", 0);
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  CHECK_OR_RETURN(in_tensor.data_type() == DataType::kTensorBuffer);
+  *out_tensor->mut_data_type() = ctx->Attr<DataType>("data_type");
+  return Maybe<void>::Ok();
+}
+
 }  // namespace oneflow
diff --git a/oneflow/user/ops/onerec_reader_op.cpp b/oneflow/user/ops/onerec_reader_op.cpp
index 697fa7a2722..b98a924b1ff 100644
--- a/oneflow/user/ops/onerec_reader_op.cpp
+++ b/oneflow/user/ops/onerec_reader_op.cpp
@@ -14,43 +14,34 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_CPU_ONLY_USER_OP("OneRecReader")
-    .Output("out")
-    .Attr<std::vector<std::string>>("files")
-    .Attr<int32_t>("batch_size")
-    .Attr<bool>("random_shuffle", false)
-    .Attr<std::string>("shuffle_mode", "instance")
-    .Attr<int64_t>("seed", -1)
-    .Attr<int32_t>("shuffle_buffer_size", 1024)
-    .Attr<bool>("shuffle_after_epoch", false)
-    .Attr<bool>("verify_example", true)
-    .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
-      int32_t local_batch_size = ctx->Attr<int32_t>("batch_size");
-      const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0);
-      int64_t parallel_num = ctx->parallel_ctx().parallel_num();
-      CHECK_OR_RETURN(parallel_num == 1 || sbp.has_split_parallel());
-      CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0);
-      local_batch_size /= parallel_num;
-      *out_tensor->mut_shape() = Shape({local_batch_size});
-      return Maybe<void>::Ok();
-    })
-    .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
-      int32_t batch_size = ctx->Attr<int32_t>("batch_size");
-      *out_tensor->mut_shape() = Shape({batch_size});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Split(ctx->outputs(), 0).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = DataType::kTensorBuffer;
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> OneRecReaderOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(ctx->outputs(), 0).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> OneRecReaderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  int32_t batch_size = ctx->Attr<int32_t>("batch_size");
+  *out_tensor->mut_shape() = Shape({batch_size});
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> OneRecReaderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0);
+  int32_t local_batch_size = ctx->Attr<int32_t>("batch_size");
+  const cfg::SbpParallel& sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0);
+  int64_t parallel_num = ctx->parallel_ctx().parallel_num();
+  CHECK_OR_RETURN(parallel_num == 1 || sbp.has_split_parallel());
+  CHECK_EQ_OR_RETURN(local_batch_size % parallel_num, 0);
+  local_batch_size /= parallel_num;
+  *out_tensor->mut_shape() = Shape({local_batch_size});
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> OneRecReaderOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = DataType::kTensorBuffer;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/ones_like_op.cpp b/oneflow/user/ops/ones_like_op.cpp
index 9c37d088b9d..cf05b880f87 100644
--- a/oneflow/user/ops/ones_like_op.cpp
+++ b/oneflow/user/ops/ones_like_op.cpp
@@ -14,35 +14,34 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("ones_like")
-    .Input("like")
-    .Output("out")
-    .SetOutputBufferNum(1)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& like_tensor =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0);
-      FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("like", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("like", 0))
-          .Broadcast(user_op::OpArg("out", 0))
-          .Build();
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> OnesLikeOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0);
+  FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("like", 0), i)
+        .Split(user_op::OpArg("out", 0), i)
+        .Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("like", 0))
+      .Broadcast(user_op::OpArg("out", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> OnesLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> OnesLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return OnesLikeOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> OnesLikeOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0);
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/p2p_comm_op.cpp b/oneflow/user/ops/p2p_comm_op.cpp
index 8bd612fa54e..3e7c06fd7b1 100644
--- a/oneflow/user/ops/p2p_comm_op.cpp
+++ b/oneflow/user/ops/p2p_comm_op.cpp
@@ -15,24 +15,27 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/user/ops/comm_net_device_infer_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
+/*static*/ Maybe<void> SendOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); }
+/*static*/ Maybe<void> SendOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  // Do nothing.
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SendOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return SendOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SendOp::InferDataType(user_op::InferContext* ctx) {
+  // Do nothing.
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<Symbol<Device>> SendOp::InferDevice(user_op::DeviceInferContext* ctx) {
+  return DeviceInferFn<&SyncLaunched>(ctx);
+}
 
-REGISTER_NO_GRAD_USER_OP("send")
-    .Input("in")
-    .Attr<int64_t>("dst_process_id")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      // Do nothing.
-      return Maybe<void>::Ok();
-    })
-    .SetDeviceInferFn(DeviceInferFn<&SyncLaunched>)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { UNIMPLEMENTED_THEN_RETURN(); })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      // Do nothing.
-      return Maybe<void>::Ok();
-    });
+namespace {
 
 Maybe<Symbol<Device>> GetRecvOutputDeivce(user_op::DeviceInferContext* ctx) {
   const std::string& device_type = ctx->Attr<std::string>("device_type");
@@ -40,24 +43,22 @@ Maybe<Symbol<Device>> GetRecvOutputDeivce(user_op::DeviceInferContext* ctx) {
   return Device::New(device_type, device_id);
 }
 
-REGISTER_NO_GRAD_USER_OP("recv")
-    .Output("out")
-    .Attr<int64_t>("src_process_id")
-    .Attr<DataType>("dtype")
-    .Attr<Shape>("shape")
-    .Attr<std::string>("device_type")
-    .Attr<int64_t>("device_id")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->Attr<Shape>("shape");
-      return Maybe<void>::Ok();
-    })
-    .SetDeviceInferFn(DeviceInferFn<&SyncLaunched, &GetRecvOutputDeivce>)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { UNIMPLEMENTED_THEN_RETURN(); })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->Attr<DataType>("dtype");
-      return Maybe<void>::Ok();
-    });
-
 }  // namespace
 
+/*static*/ Maybe<void> RecvOp::GetSbp(user_op::SbpContext* ctx) { UNIMPLEMENTED_THEN_RETURN(); }
+/*static*/ Maybe<void> RecvOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->Attr<Shape>("shape");
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> RecvOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return RecvOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> RecvOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->Attr<DataType>("dtype");
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<Symbol<Device>> RecvOp::InferDevice(user_op::DeviceInferContext* ctx) {
+  return DeviceInferFn<&SyncLaunched, &GetRecvOutputDeivce>(ctx);
+}
+
 }  // namespace oneflow
diff --git a/oneflow/user/ops/pack_op.cpp b/oneflow/user/ops/pack_op.cpp
index e28529d542c..b5ae5c75a74 100644
--- a/oneflow/user/ops/pack_op.cpp
+++ b/oneflow/user/ops/pack_op.cpp
@@ -14,61 +14,59 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
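Every file touched so far repeats the same mechanical rewrite, and the rest of the patch continues it: the REGISTER_*_USER_OP("name") builder chain with its .Input/.Output/.Attr declarations and Set*Fn lambdas is deleted, the op interface moves into the generated schema behind the new `#include "oneflow/core/framework/op_generated.h"`, and each lambda becomes a static member of the generated class. A condensed before/after sketch of that shape; the op name `foo` and class `FooOp` are hypothetical, not part of the patch:

// Before: behavior registered through lambdas on a builder.
// REGISTER_USER_OP("foo")
//     .Input("in")
//     .Output("out")
//     .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { ... });
//
// After: the .cpp only fills in the statics of the generated class.
/* static */ Maybe<void> FooOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
  return Maybe<void>::Ok();
}
/* static */ Maybe<void> FooOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  // Most ops in this patch delegate rank-local inference to the logical rule.
  return InferLogicalTensorDesc(ctx);
}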
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
+/*static*/ Maybe<void> PackOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {
+    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("in", 0))
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> PackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
+  const Shape& in_shape = in_desc.shape();
+  const int32_t pack_num = ctx->Attr<int32_t>("pack_num");
+  CHECK_GT_OR_RETURN(pack_num, 0);
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  *out_desc->mut_is_dynamic() = in_desc.is_dynamic();
+  if (in_shape.NumAxes() > 0) {
+    *out_desc->mut_shape() = in_shape;
+    out_desc->mut_shape()->Set(0, in_shape.At(0) * pack_num);
+  } else {
+    // NOTE(chengcheng): for Scalar input pack
+    CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), 1);
+    *out_desc->mut_shape() = Shape({pack_num});
+  }
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("pack")
-    .Input("in")
-    .Output("out")
-    .Attr<int32_t>("pack_num")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
-      const Shape& in_shape = in_desc.shape();
-      const int32_t pack_num = ctx->Attr<int32_t>("pack_num");
-      CHECK_GT_OR_RETURN(pack_num, 0);
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      *out_desc->mut_is_dynamic() = in_desc.is_dynamic();
-      if (in_shape.NumAxes() > 0) {
-        *out_desc->mut_shape() = in_shape;
-        out_desc->mut_shape()->Set(0, in_shape.At(0) * pack_num);
-      } else {
-        // NOTE(chengcheng): for Scalar input pack
-        CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), 1);
-        *out_desc->mut_shape() = Shape({pack_num});
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("in", 0))
-          .PartialSum(user_op::OpArg("out", 0))
-          .Build();
-      return Maybe<void>::Ok();
-    })
-    .SetOutputBlobTimeShapeInferFn(
-        [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe<void> {
-          const int32_t pack_num = ctx->user_op_conf().attr<int32_t>("pack_num");
-          DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec();
-          CHECK_OR_RETURN(!time_shape_dim_vec.empty());
-          CHECK_EQ_OR_RETURN(time_shape_dim_vec.back(), pack_num);
-          time_shape_dim_vec.pop_back();
-          if (time_shape_dim_vec.empty()) { time_shape_dim_vec.emplace_back(1); }
-          *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec);
-          return Maybe<void>::Ok();
-        })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> PackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return PackOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> PackOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> PackOp::InferOutputBlobTimeShape(
+    user_op::InferOutputBlobTimeShapeFnContext* ctx) {
+  const int32_t pack_num = ctx->user_op_conf().attr<int32_t>("pack_num");
+  DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec();
+  CHECK_OR_RETURN(!time_shape_dim_vec.empty());
+  CHECK_EQ_OR_RETURN(time_shape_dim_vec.back(), pack_num);
+  time_shape_dim_vec.pop_back();
+  if (time_shape_dim_vec.empty()) { time_shape_dim_vec.emplace_back(1); }
+  *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec);
+  return Maybe<void>::Ok();
+}
+
+namespace {
 
 REGISTER_USER_OP_GRAD("pack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
diff --git a/oneflow/user/ops/pad_op.cpp b/oneflow/user/ops/pad_op.cpp
index a446b29ce90..b222bdc2ad7 100644
--- a/oneflow/user/ops/pad_op.cpp
+++ b/oneflow/user/ops/pad_op.cpp
@@ -15,86 +15,73 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/common/balanced_splitter.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("pad")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("padding_before")
-    .Attr<std::vector<int64_t>>("padding_after")
-    .Attr<double>("floating_constant_value")
-    .Attr<int64_t>("integral_constant_value")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
-      const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
-      CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes());
-      CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes());
-      DimVector y_dim_vec(x_shape.NumAxes());
-      FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {
-        y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i];
-      }
-      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
-      const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
-      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
-        if (padding_before[i] == 0 && padding_after[i] == 0) {
-          ctx->NewBuilder()
-              .Split(user_op::OpArg("x", 0), i)
-              .Split(user_op::OpArg("y", 0), i)
-              .Build();
-        }
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> PadOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
+  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+    if (padding_before[i] == 0 && padding_after[i] == 0) {
+      ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build();
+    }
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> PadOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
+  CHECK_EQ_OR_RETURN(padding_before.size(), x_shape.NumAxes());
+  CHECK_EQ_OR_RETURN(padding_after.size(), x_shape.NumAxes());
+  DimVector y_dim_vec(x_shape.NumAxes());
+  FOR_RANGE(int64_t, i, 0, x_shape.NumAxes()) {
+    y_dim_vec[i] = x_shape.At(i) + padding_before[i] + padding_after[i];
+  }
+  *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
+  return Maybe<void>::Ok();
+}
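PadOp::GetSbp above emits a Split signature only for axes with zero padding on both sides; splitting a padded axis across ranks would put the constant border inside every shard instead of only at the logical tensor's edges. The shape rule itself is y[i] = x[i] + padding_before[i] + padding_after[i]; a small self-contained check of both facts (illustrative code, not part of the patch):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> x = {8, 16};  // input shape
  const std::vector<int64_t> before = {0, 2};
  const std::vector<int64_t> after = {0, 1};
  std::vector<int64_t> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) { y[i] = x[i] + before[i] + after[i]; }
  assert(y[0] == 8 && y[1] == 19);
  // Only axis 0 is pad-free on both sides, so only axis 0 may be split.
  return 0;
}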
+/*static*/ Maybe<void> PadOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return PadOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> PadOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("pad_grad")
-    .Input("dy")
-    .Output("dx")
-    .Attr<std::vector<int64_t>>("padding_before")
-    .Attr<std::vector<int64_t>>("padding_after")
-    .Attr<double>("floating_constant_value")
-    .Attr<int64_t>("integral_constant_value")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
-      const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
-      CHECK_EQ_OR_RETURN(padding_before.size(), dy_shape.NumAxes());
-      CHECK_EQ_OR_RETURN(padding_after.size(), dy_shape.NumAxes());
-      DimVector dx_dim_vec(dy_shape.NumAxes());
-      FOR_RANGE(int64_t, i, 0, dy_shape.NumAxes()) {
-        dx_dim_vec[i] = dy_shape.At(i) - padding_before[i] - padding_after[i];
-      }
-      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
-      const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
-      const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
-      FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) {
-        if (padding_before[i] == 0 && padding_after[i] == 0) {
-          ctx->NewBuilder()
-              .Split(user_op::OpArg("dx", 0), i)
-              .Split(user_op::OpArg("dy", 0), i)
-              .Build();
-        }
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> PadGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& dy_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0);
+  const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
+  FOR_RANGE(int64_t, i, 0, dy_tensor.shape().NumAxes()) {
+    if (padding_before[i] == 0 && padding_after[i] == 0) {
+      ctx->NewBuilder().Split(user_op::OpArg("dx", 0), i).Split(user_op::OpArg("dy", 0), i).Build();
+    }
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> PadGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  const auto& padding_before = ctx->Attr<std::vector<int64_t>>("padding_before");
+  const auto& padding_after = ctx->Attr<std::vector<int64_t>>("padding_after");
+  CHECK_EQ_OR_RETURN(padding_before.size(), dy_shape.NumAxes());
+  CHECK_EQ_OR_RETURN(padding_after.size(), dy_shape.NumAxes());
+  DimVector dx_dim_vec(dy_shape.NumAxes());
+  FOR_RANGE(int64_t, i, 0, dy_shape.NumAxes()) {
+    dx_dim_vec[i] = dy_shape.At(i) - padding_before[i] - padding_after[i];
+  }
+  *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> PadGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return PadGradOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> PadGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("pad").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                                        user_op::AddOpFn AddOp) -> Maybe<void> {
diff --git a/oneflow/user/ops/padding_ops.cpp b/oneflow/user/ops/padding_ops.cpp
index 46d97d8520c..969bd1a5721 100644
--- a/oneflow/user/ops/padding_ops.cpp
+++ b/oneflow/user/ops/padding_ops.cpp
@@ -16,6 +16,7 @@ limitations under the License.
 #include "oneflow/core/common/balanced_splitter.h"
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/user/ops/nn_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
@@ -46,89 +47,82 @@ Maybe<void> GetOpGradSbpSignature(user_op::SbpContext* ctx) {
 
 }  // namespace
 
-REGISTER_USER_OP("reflection_pad2d")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes());
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t h_idx = 2;
-      const int64_t w_idx = 3;
-
-      // Ensure the padding size is less than the input dimension.
-      CHECK_LT_OR_RETURN(padding[0], x_shape.At(w_idx));
-      CHECK_LT_OR_RETURN(padding[1], x_shape.At(w_idx));
-      CHECK_LT_OR_RETURN(padding[2], x_shape.At(h_idx));
-      CHECK_LT_OR_RETURN(padding[3], x_shape.At(h_idx));
-
-      DimVector y_dim_vec(x_shape.NumAxes());
-      const int64_t h_x = x_shape.At(h_idx);
-      const int64_t w_x = x_shape.At(w_idx);
-
-      y_dim_vec[n_idx] = x_shape.At(n_idx);
-      y_dim_vec[c_idx] = x_shape.At(c_idx);
-      y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
-      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
-
-      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpSbpSignature)
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
-      CHECK_NOTNULL_OR_RETURN(x_modifier);
-      x_modifier->set_requires_grad(true);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> ReflectionPad2DOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetOpSbpSignature(ctx);
+}
+/*static*/ Maybe<void> ReflectionPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+  CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes());
+  const int64_t n_idx = 0;
+  const int64_t c_idx = 1;
+  const int64_t h_idx = 2;
+  const int64_t w_idx = 3;
+
+  // Ensure the padding size is less than the input dimension.
+  CHECK_LT_OR_RETURN(padding[0], x_shape.At(w_idx));
+  CHECK_LT_OR_RETURN(padding[1], x_shape.At(w_idx));
+  CHECK_LT_OR_RETURN(padding[2], x_shape.At(h_idx));
+  CHECK_LT_OR_RETURN(padding[3], x_shape.At(h_idx));
+
+  DimVector y_dim_vec(x_shape.NumAxes());
+  const int64_t h_x = x_shape.At(h_idx);
+  const int64_t w_x = x_shape.At(w_idx);
+
+  y_dim_vec[n_idx] = x_shape.At(n_idx);
+  y_dim_vec[c_idx] = x_shape.At(c_idx);
+  y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
+  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
+
+  *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ReflectionPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return ReflectionPad2DOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> ReflectionPad2DOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ReflectionPad2DOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
+  CHECK_NOTNULL_OR_RETURN(x_modifier);
+  x_modifier->set_requires_grad(true);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("reflection_pad2d_grad")
-    .Input("dy")
-    .Output("dx")
-    .Attr<std::vector<int64_t>>("padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes());
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t h_idx = 2;
-      const int64_t w_idx = 3;
-
-      DimVector dx_dim_vec(dy_shape.NumAxes());
-      int64_t h_dy, w_dy;
-      h_dy = dy_shape.At(h_idx);
-      w_dy = dy_shape.At(w_idx);
-
-      dx_dim_vec[n_idx] = dy_shape.At(0);
-      dx_dim_vec[c_idx] = dy_shape.At(1);
-      dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
-      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
-
-      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpGradSbpSignature)
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> ReflectionPad2DGradOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetOpGradSbpSignature(ctx);
+}
+/*static*/ Maybe<void> ReflectionPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+  CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes());
+  const int64_t n_idx = 0;
+  const int64_t c_idx = 1;
+  const int64_t h_idx = 2;
+  const int64_t w_idx = 3;
+
+  DimVector dx_dim_vec(dy_shape.NumAxes());
+  int64_t h_dy = dy_shape.At(h_idx);
+  int64_t w_dy = dy_shape.At(w_idx);
+
+  dx_dim_vec[n_idx] = dy_shape.At(0);
+  dx_dim_vec[c_idx] = dy_shape.At(1);
+  dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
+  dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
+
+  *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ReflectionPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return ReflectionPad2DGradOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> ReflectionPad2DGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("reflection_pad2d")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
@@ -147,83 +141,76 @@ REGISTER_USER_OP_GRAD("reflection_pad2d")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP("replication_pad2d")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes());
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t h_idx = 2;
-      const int64_t w_idx = 3;
-
-      DimVector y_dim_vec(x_shape.NumAxes());
-      const int64_t h_x = x_shape.At(h_idx);
-      const int64_t w_x = x_shape.At(w_idx);
-
-      y_dim_vec[n_idx] = x_shape.At(n_idx);
-      y_dim_vec[c_idx] = x_shape.At(c_idx);
-      y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
-      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
-
-      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpSbpSignature)
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
-      CHECK_NOTNULL_OR_RETURN(x_modifier);
-      x_modifier->set_requires_grad(true);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> ReplicationPad2DOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetOpSbpSignature(ctx);
+}
+/*static*/ Maybe<void> ReplicationPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+  CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes());
+  const int64_t n_idx = 0;
+  const int64_t c_idx = 1;
+  const int64_t h_idx = 2;
+  const int64_t w_idx = 3;
+
+  DimVector y_dim_vec(x_shape.NumAxes());
+  const int64_t h_x = x_shape.At(h_idx);
+  const int64_t w_x = x_shape.At(w_idx);
+
+  y_dim_vec[n_idx] = x_shape.At(n_idx);
+  y_dim_vec[c_idx] = x_shape.At(c_idx);
+  y_dim_vec[h_idx] = h_x + padding[2] + padding[3];
+  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
+
+  *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ReplicationPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return ReplicationPad2DOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> ReplicationPad2DOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ReplicationPad2DOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
+  CHECK_NOTNULL_OR_RETURN(x_modifier);
+  x_modifier->set_requires_grad(true);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("replication_pad2d_grad")
-    .Input("dy")
-    .Output("dx")
-    .Attr<std::vector<int64_t>>("padding")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes());
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t h_idx = 2;
-      const int64_t w_idx = 3;
-      DimVector dx_dim_vec(dy_shape.NumAxes());
-      int64_t h_dy, w_dy;
-      h_dy = dy_shape.At(h_idx);
-      w_dy = dy_shape.At(w_idx);
-
-      dx_dim_vec[n_idx] = dy_shape.At(0);
-      dx_dim_vec[c_idx] = dy_shape.At(1);
-      dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
-      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
-
-      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpGradSbpSignature)
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> ReplicationPad2DGradOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetOpGradSbpSignature(ctx);
+}
+/*static*/ Maybe<void> ReplicationPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+  CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes());
+  const int64_t n_idx = 0;
+  const int64_t c_idx = 1;
+  const int64_t h_idx = 2;
+  const int64_t w_idx = 3;
+
+  DimVector dx_dim_vec(dy_shape.NumAxes());
+  int64_t h_dy = dy_shape.At(h_idx);
+  int64_t w_dy = dy_shape.At(w_idx);
+
+  dx_dim_vec[n_idx] = dy_shape.At(0);
+  dx_dim_vec[c_idx] = dy_shape.At(1);
+  dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3];
+  dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
+
+  *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ReplicationPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return ReplicationPad2DGradOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> ReplicationPad2DGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("replication_pad2d")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
@@ -242,83 +229,72 @@ REGISTER_USER_OP_GRAD("replication_pad2d")
       return Maybe<void>::Ok();
     });
 
-REGISTER_USER_OP("constant_pad1d")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("padding")
-    .Attr<double>("floating_value")
-    .Attr<int64_t>("integral_value")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 3);
-      CHECK_EQ_OR_RETURN(padding.size(), 2);
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t w_idx = 2;
-
-      DimVector y_dim_vec(x_shape.NumAxes());
-      const int64_t w_x = x_shape.At(w_idx);
-
-      y_dim_vec[n_idx] = x_shape.At(n_idx);
-      y_dim_vec[c_idx] = x_shape.At(c_idx);
-      y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
-
-      *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpSbpSignature)
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
-      CHECK_NOTNULL_OR_RETURN(x_modifier);
-      x_modifier->set_requires_grad(true);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> ConstantPad1DOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetOpSbpSignature(ctx);
+}
+/*static*/ Maybe<void> ConstantPad1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+  CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 3);
+  CHECK_EQ_OR_RETURN(padding.size(), 2);
+  const int64_t n_idx = 0;
+  const int64_t c_idx = 1;
+  const int64_t w_idx = 2;
+
+  DimVector y_dim_vec(x_shape.NumAxes());
+  const int64_t w_x = x_shape.At(w_idx);
+
+  y_dim_vec[n_idx] = x_shape.At(n_idx);
+  y_dim_vec[c_idx] = x_shape.At(c_idx);
+  y_dim_vec[w_idx] = w_x + padding[0] + padding[1];
+
+  *ctx->OutputShape("y", 0) = Shape(y_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ConstantPad1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return ConstantPad1DOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> ConstantPad1DOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ConstantPad1DOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0);
+  CHECK_NOTNULL_OR_RETURN(x_modifier);
+  x_modifier->set_requires_grad(true);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("constant_pad1d_grad")
-    .Input("dy")
-    .Output("dx")
-    .Attr<std::vector<int64_t>>("padding")
-    .Attr<double>("floating_value")
-    .Attr<int64_t>("integral_value")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
-      CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 3);
-      CHECK_EQ_OR_RETURN(padding.size(), 2);
-      const int64_t n_idx = 0;
-      const int64_t c_idx = 1;
-      const int64_t w_idx = 2;
-
-      DimVector dx_dim_vec(dy_shape.NumAxes());
-      int64_t w_dy;
-      w_dy = dy_shape.At(w_idx);
-
-      dx_dim_vec[n_idx] = dy_shape.At(0);
-      dx_dim_vec[c_idx] = dy_shape.At(1);
-      dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
-
-      *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(GetOpGradSbpSignature)
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> ConstantPad1DGradOp::GetSbp(user_op::SbpContext* ctx) {
+  return GetOpGradSbpSignature(ctx);
+}
+/*static*/ Maybe<void> ConstantPad1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  const auto& padding = ctx->Attr<std::vector<int64_t>>("padding");
+  CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 3);
+  CHECK_EQ_OR_RETURN(padding.size(), 2);
+  const int64_t n_idx = 0;
+  const int64_t c_idx = 1;
+  const int64_t w_idx = 2;
+
+  DimVector dx_dim_vec(dy_shape.NumAxes());
+  int64_t w_dy = dy_shape.At(w_idx);
+
+  dx_dim_vec[n_idx] = dy_shape.At(0);
+  dx_dim_vec[c_idx] = dy_shape.At(1);
+  dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1];
+
+  *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> ConstantPad1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return ConstantPad1DGradOp::InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> ConstantPad1DGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
  return Maybe<void>::Ok();
+}
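ConstantPad1DGradOp::InferLogicalTensorDesc above is the exact inverse of the forward shape rule: the forward op grows the width by padding[0] + padding[1] and the grad op shrinks dy by the same amounts, so dx always lands back on the input shape. A quick sanity check of that round trip (illustrative, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  const int64_t w_x = 10, pad_left = 2, pad_right = 3;
  const int64_t w_y = w_x + pad_left + pad_right;   // forward: 15
  const int64_t w_dx = w_y - pad_left - pad_right;  // backward: recovers 10
  assert(w_dx == w_x);
  return 0;
}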
REGISTER_USER_OP_GRAD("constant_pad1d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -339,87 +315,76 @@ REGISTER_USER_OP_GRAD("constant_pad1d") return Maybe::Ok(); }); -REGISTER_USER_OP("constant_pad2d") - .Input("x") - .Output("y") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t h_x = x_shape.At(h_idx); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad2DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), x_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t h_x = x_shape.At(h_idx); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad2DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad2DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("constant_pad2d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); - const int64_t n_idx = 0; - const int64_t c_idx = 
1; - const int64_t h_idx = 2; - const int64_t w_idx = 3; - - DimVector dx_dim_vec(dy_shape.NumAxes()); - int64_t h_dy, w_dy; - h_dy = dy_shape.At(h_idx); - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad2DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(padding.size(), dy_shape.NumAxes()); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t h_idx = 2; + const int64_t w_idx = 3; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t h_dy = dy_shape.At(h_idx); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad2DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad2DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("constant_pad2d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -440,95 +405,84 @@ REGISTER_USER_OP_GRAD("constant_pad2d") return Maybe::Ok(); }); -REGISTER_USER_OP("constant_pad3d") - .Input("x") - .Output("y") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5); - // only support NCDHW format input tensor for now ! 
- // for NCDHW format, index of num,channel,depth,height,width is 0,1,2,3,4 - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t d_idx = 2; - const int64_t h_idx = 3; - const int64_t w_idx = 4; - - DimVector y_dim_vec(x_shape.NumAxes()); - const int64_t d_x = x_shape.At(d_idx); - const int64_t h_x = x_shape.At(h_idx); - const int64_t w_x = x_shape.At(w_idx); - - y_dim_vec[n_idx] = x_shape.At(n_idx); - y_dim_vec[c_idx] = x_shape.At(c_idx); - y_dim_vec[d_idx] = d_x + padding[4] + padding[5]; - y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; - y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; - - *ctx->OutputShape("y", 0) = Shape(y_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpSbpSignature) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); - CHECK_NOTNULL_OR_RETURN(x_modifier); - x_modifier->set_requires_grad(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad3DOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(x_shape.NumAxes(), 5); + // only support NCDHW format input tensor for now ! + // for NCDHW format, index of num,channel,depth,height,width is 0,1,2,3,4 + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + + DimVector y_dim_vec(x_shape.NumAxes()); + const int64_t d_x = x_shape.At(d_idx); + const int64_t h_x = x_shape.At(h_idx); + const int64_t w_x = x_shape.At(w_idx); + + y_dim_vec[n_idx] = x_shape.At(n_idx); + y_dim_vec[c_idx] = x_shape.At(c_idx); + y_dim_vec[d_idx] = d_x + padding[4] + padding[5]; + y_dim_vec[h_idx] = h_x + padding[2] + padding[3]; + y_dim_vec[w_idx] = w_x + padding[0] + padding[1]; + + *ctx->OutputShape("y", 0) = Shape(y_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad3DOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad3DOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* x_modifier = GetInputArgModifierFn("x", 0); + CHECK_NOTNULL_OR_RETURN(x_modifier); + x_modifier->set_requires_grad(true); + return Maybe::Ok(); +} -REGISTER_USER_OP("constant_pad3d_grad") - .Input("dy") - .Output("dx") - .Attr>("padding") - .Attr("floating_value") - .Attr("integral_value") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - const auto& padding = ctx->Attr>("padding"); - CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5); - const int64_t n_idx = 0; - const int64_t c_idx = 1; - const int64_t d_idx = 2; - const int64_t h_idx = 3; - const int64_t w_idx = 4; - - DimVector 
dx_dim_vec(dy_shape.NumAxes()); - int64_t d_dy, h_dy, w_dy; - d_dy = dy_shape.At(d_idx); - h_dy = dy_shape.At(h_idx); - w_dy = dy_shape.At(w_idx); - - dx_dim_vec[n_idx] = dy_shape.At(0); - dx_dim_vec[c_idx] = dy_shape.At(1); - dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5]; - dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; - dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; - - *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(GetOpGradSbpSignature) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ConstantPad3DGradOp::GetSbp(user_op::SbpContext* ctx) { + return GetOpGradSbpSignature(ctx); +} +/*static*/ Maybe ConstantPad3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + const auto& padding = ctx->Attr>("padding"); + CHECK_EQ_OR_RETURN(dy_shape.NumAxes(), 5); + const int64_t n_idx = 0; + const int64_t c_idx = 1; + const int64_t d_idx = 2; + const int64_t h_idx = 3; + const int64_t w_idx = 4; + + DimVector dx_dim_vec(dy_shape.NumAxes()); + int64_t d_dy = dy_shape.At(d_idx); + int64_t h_dy = dy_shape.At(h_idx); + int64_t w_dy = dy_shape.At(w_idx); + + dx_dim_vec[n_idx] = dy_shape.At(0); + dx_dim_vec[c_idx] = dy_shape.At(1); + dx_dim_vec[d_idx] = d_dy - padding[4] - padding[5]; + dx_dim_vec[h_idx] = h_dy - padding[2] - padding[3]; + dx_dim_vec[w_idx] = w_dy - padding[0] - padding[1]; + + *ctx->OutputShape("dx", 0) = Shape(dx_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe ConstantPad3DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ConstantPad3DGradOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ConstantPad3DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("constant_pad3d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/parallel_cast_op.cpp b/oneflow/user/ops/parallel_cast_op.cpp index b31cf919f4b..b4762acf2d4 100644 --- a/oneflow/user/ops/parallel_cast_op.cpp +++ b/oneflow/user/ops/parallel_cast_op.cpp @@ -15,49 +15,50 @@ limitations under the License. 
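The parallel_cast rewrite below keeps the string-typed sbp_parallel attribute and its validation intact: an empty string forwards the producer's SBP hint to both blob names, otherwise the string is parsed and, for a split signature, the axis is bounds-checked against the input rank. A minimal standalone sketch of that parse-and-validate step, assuming the usual "S(axis)" / "B" / "P" text forms; SbpHint and ParseSbpString are illustrative names, not OneFlow API:

#include <cstdint>
#include <optional>
#include <string>

struct SbpHint {
  enum Kind { kSplit, kBroadcast, kPartialSum } kind;
  int64_t split_axis = -1;  // meaningful only when kind == kSplit
};

std::optional<SbpHint> ParseSbpString(const std::string& s) {
  if (s == "B") return SbpHint{SbpHint::kBroadcast};
  if (s == "P") return SbpHint{SbpHint::kPartialSum};
  if (s.size() >= 4 && s.front() == 'S' && s.back() == ')') {
    // "S(axis)": take the digits between the parentheses.
    return SbpHint{SbpHint::kSplit, std::stol(s.substr(2, s.size() - 3))};
  }
  return std::nullopt;  // mirrors the CHECK_OR_RETURN failure path
}

bool SplitAxisValid(const SbpHint& hint, int64_t num_axes) {
  // Same bounds the op enforces: 0 <= split_axis < num_axes.
  return hint.kind != SbpHint::kSplit
         || (hint.split_axis >= 0 && hint.split_axis < num_axes);
}
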
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("parallel_cast") - .Input("in") - .Output("out") - .Attr("sbp_parallel", "") - .Attr("grad_sbp_parallel", "") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetSbpSignatureInferFn([](user_op::InferSbpSignatureFnContext* ctx) -> Maybe { - auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel(); - const std::string& ibn = GenRepeatedBn("in", 0); - const std::string& obn = GenRepeatedBn("out", 0); - const auto& sbp_parallel_str = ctx->Attr("sbp_parallel"); - if (sbp_parallel_str.empty()) { - const auto& sbp_parallel = ctx->SbpParallelHint4InputArgNameAndIndex("in", 0); - (*bn2sbp)[ibn] = sbp_parallel; - (*bn2sbp)[obn] = sbp_parallel; - } else { - cfg::SbpParallel sbp_parallel; - CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel)) - << "invalid sbp_parallel: " << sbp_parallel_str; - if (sbp_parallel.has_split_parallel()) { - int64_t split_axis = sbp_parallel.split_parallel().axis(); - const auto& in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - int64_t num_axes = in_desc.shape().NumAxes(); - CHECK_GE_OR_RETURN(split_axis, 0); - CHECK_LT_OR_RETURN(split_axis, num_axes); - } - (*bn2sbp)[ibn] = sbp_parallel; - (*bn2sbp)[obn] = sbp_parallel; - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe ParallelCastOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe ParallelCastOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ParallelCastOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return ParallelCastOp::InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ParallelCastOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ParallelCastOp::InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx) { + auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel(); + const std::string& ibn = GenRepeatedBn("in", 0); + const std::string& obn = GenRepeatedBn("out", 0); + const auto& sbp_parallel_str = ctx->Attr("sbp_parallel"); + if (sbp_parallel_str.empty()) { + const auto& sbp_parallel = ctx->SbpParallelHint4InputArgNameAndIndex("in", 0); + (*bn2sbp)[ibn] = sbp_parallel; + (*bn2sbp)[obn] = sbp_parallel; + } else { + cfg::SbpParallel sbp_parallel; + CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel)) + << "invalid sbp_parallel: " << sbp_parallel_str; + if (sbp_parallel.has_split_parallel()) { + int64_t split_axis = sbp_parallel.split_parallel().axis(); + const auto& in_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + int64_t num_axes = in_desc.shape().NumAxes(); + CHECK_GE_OR_RETURN(split_axis, 0); + CHECK_LT_OR_RETURN(split_axis, num_axes); + } + (*bn2sbp)[ibn] = sbp_parallel; + 
(*bn2sbp)[obn] = sbp_parallel; + } + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("parallel_cast") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/partial_fc_sample_op.cpp b/oneflow/user/ops/partial_fc_sample_op.cpp index b40d6f94d8d..1798e91fe6d 100644 --- a/oneflow/user/ops/partial_fc_sample_op.cpp +++ b/oneflow/user/ops/partial_fc_sample_op.cpp @@ -14,127 +14,119 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("distributed_partial_fc_sample") - .Input("weight") - .Input("label") - .Output("mapped_label") - .Output("sampled_label") - .Output("sampled_weight") - .Attr("num_sample") - .Attr("seed", -1) - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const int64_t num_sample = ctx->Attr("num_sample"); - const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); - user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); - user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); - *mapped_label->mut_shape() = label.shape(); - *mapped_label->mut_is_dynamic() = label.is_dynamic(); - *sampled_weight->mut_shape() = weight.shape(); - sampled_weight->mut_shape()->Set(0, num_sample); - *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); - *sampled_label->mut_shape() = label.shape(); - sampled_label->mut_shape()->Set(0, num_sample); - *sampled_label->mut_is_dynamic() = label.is_dynamic(); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const int64_t num_sample = ctx->Attr("num_sample"); - const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); - CHECK_EQ_OR_RETURN(num_sample % parallel_num, 0); - const int64_t num_sample_per_rank = num_sample / parallel_num; - const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); - const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); - user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); - user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); - user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); - *mapped_label->mut_shape() = label.shape(); - *mapped_label->mut_is_dynamic() = label.is_dynamic(); - *sampled_weight->mut_shape() = weight.shape(); - sampled_weight->mut_shape()->Set(0, num_sample_per_rank); - *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); - *sampled_label->mut_shape() = label.shape(); - sampled_label->mut_shape()->Set(0, num_sample_per_rank); - *sampled_label->mut_is_dynamic() = label.is_dynamic(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); - CHECK_NOTNULL_OR_RETURN(label_modifier); - label_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("weight", 0), 0) - .Broadcast(user_op::OpArg("label", 0)) - .Broadcast(user_op::OpArg("mapped_label", 0)) - .Split(user_op::OpArg("sampled_label", 0), 0) - 
.Split(user_op::OpArg("sampled_weight", 0), 0) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("mapped_label", 0) = ctx->InputDType("label", 0); - *ctx->OutputDType("sampled_weight", 0) = ctx->InputDType("weight", 0); - *ctx->OutputDType("sampled_label", 0) = ctx->InputDType("label", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe DistributedPartialFcSampleOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("weight", 0), 0) + .Broadcast(user_op::OpArg("label", 0)) + .Broadcast(user_op::OpArg("mapped_label", 0)) + .Split(user_op::OpArg("sampled_label", 0), 0) + .Split(user_op::OpArg("sampled_weight", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const int64_t num_sample = ctx->Attr("num_sample"); + const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); + user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); + user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); + *mapped_label->mut_shape() = label.shape(); + *mapped_label->mut_is_dynamic() = label.is_dynamic(); + *sampled_weight->mut_shape() = weight.shape(); + sampled_weight->mut_shape()->Set(0, num_sample); + *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); + *sampled_label->mut_shape() = label.shape(); + sampled_label->mut_shape()->Set(0, num_sample); + *sampled_label->mut_is_dynamic() = label.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + const int64_t num_sample = ctx->Attr("num_sample"); + const int64_t parallel_num = ctx->parallel_ctx().parallel_num(); + CHECK_EQ_OR_RETURN(num_sample % parallel_num, 0); + const int64_t num_sample_per_rank = num_sample / parallel_num; + const user_op::TensorDesc& weight = ctx->InputTensorDesc("weight", 0); + const user_op::TensorDesc& label = ctx->InputTensorDesc("label", 0); + user_op::TensorDesc* mapped_label = ctx->OutputTensorDesc("mapped_label", 0); + user_op::TensorDesc* sampled_weight = ctx->OutputTensorDesc("sampled_weight", 0); + user_op::TensorDesc* sampled_label = ctx->OutputTensorDesc("sampled_label", 0); + *mapped_label->mut_shape() = label.shape(); + *mapped_label->mut_is_dynamic() = label.is_dynamic(); + *sampled_weight->mut_shape() = weight.shape(); + sampled_weight->mut_shape()->Set(0, num_sample_per_rank); + *sampled_weight->mut_is_dynamic() = weight.is_dynamic(); + *sampled_label->mut_shape() = label.shape(); + sampled_label->mut_shape()->Set(0, num_sample_per_rank); + *sampled_label->mut_is_dynamic() = label.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("mapped_label", 0) = ctx->InputDType("label", 0); + *ctx->OutputDType("sampled_weight", 0) = ctx->InputDType("weight", 0); + *ctx->OutputDType("sampled_label", 0) = ctx->InputDType("label", 0); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); + CHECK_NOTNULL_OR_RETURN(label_modifier); + 
label_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("distributed_partial_fc_sample_disable_boxing") - .Input("sampled_weight_diff") - .Input("sampled_label") - .Output("boxing_disabled_sampled_weight_diff") - .Output("boxing_disabled_sampled_label") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* boxing_disabled_sampled_weight_diff = - ctx->OutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); - *boxing_disabled_sampled_weight_diff->mut_shape() = ctx->InputShape("sampled_weight_diff", 0); - CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff->shape().At(0) % ctx->parallel_num(), - 0); - boxing_disabled_sampled_weight_diff->mut_shape()->Set( - 0, boxing_disabled_sampled_weight_diff->shape().At(0) / ctx->parallel_num()); - *boxing_disabled_sampled_weight_diff->mut_is_dynamic() = - ctx->InputIsDynamic("sampled_weight_diff", 0); - user_op::TensorDesc* boxing_disabled_sampled_label = - ctx->OutputTensorDesc("boxing_disabled_sampled_label", 0); - *boxing_disabled_sampled_label->mut_shape() = ctx->InputShape("sampled_label", 0); - CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label->shape().At(0) % ctx->parallel_num(), 0); - boxing_disabled_sampled_label->mut_shape()->Set( - 0, boxing_disabled_sampled_label->shape().At(0) / ctx->parallel_num()); - *boxing_disabled_sampled_label->mut_is_dynamic() = ctx->InputIsDynamic("sampled_label", 0); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = - ctx->InputShape("sampled_weight_diff", 0); - *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = - ctx->InputIsDynamic("sampled_weight_diff", 0); - *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); - *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = - ctx->InputIsDynamic("sampled_label", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("sampled_weight_diff", 0), 0) - .Split(user_op::OpArg("sampled_label", 0), 0) - .Broadcast(user_op::OpArg("boxing_disabled_sampled_weight_diff", 0)) - .Broadcast(user_op::OpArg("boxing_disabled_sampled_label", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("boxing_disabled_sampled_weight_diff", 0) = - ctx->InputDType("sampled_weight_diff", 0); - *ctx->OutputDType("boxing_disabled_sampled_label", 0) = ctx->InputDType("sampled_label", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("sampled_weight_diff", 0), 0) + .Split(user_op::OpArg("sampled_label", 0), 0) + .Broadcast(user_op::OpArg("boxing_disabled_sampled_weight_diff", 0)) + .Broadcast(user_op::OpArg("boxing_disabled_sampled_label", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + user_op::TensorDesc* boxing_disabled_sampled_weight_diff = + ctx->OutputTensorDesc("boxing_disabled_sampled_weight_diff", 0); + *boxing_disabled_sampled_weight_diff->mut_shape() = ctx->InputShape("sampled_weight_diff", 0); + CHECK_EQ_OR_RETURN(boxing_disabled_sampled_weight_diff->shape().At(0) % ctx->parallel_num(), 0); + boxing_disabled_sampled_weight_diff->mut_shape()->Set( + 0, 
boxing_disabled_sampled_weight_diff->shape().At(0) / ctx->parallel_num()); + *boxing_disabled_sampled_weight_diff->mut_is_dynamic() = + ctx->InputIsDynamic("sampled_weight_diff", 0); + user_op::TensorDesc* boxing_disabled_sampled_label = + ctx->OutputTensorDesc("boxing_disabled_sampled_label", 0); + *boxing_disabled_sampled_label->mut_shape() = ctx->InputShape("sampled_label", 0); + CHECK_EQ_OR_RETURN(boxing_disabled_sampled_label->shape().At(0) % ctx->parallel_num(), 0); + boxing_disabled_sampled_label->mut_shape()->Set( + 0, boxing_disabled_sampled_label->shape().At(0) / ctx->parallel_num()); + *boxing_disabled_sampled_label->mut_is_dynamic() = ctx->InputIsDynamic("sampled_label", 0); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + *ctx->OutputShape("boxing_disabled_sampled_weight_diff", 0) = + ctx->InputShape("sampled_weight_diff", 0); + *ctx->OutputIsDynamic("boxing_disabled_sampled_weight_diff", 0) = + ctx->InputIsDynamic("sampled_weight_diff", 0); + *ctx->OutputShape("boxing_disabled_sampled_label", 0) = ctx->InputShape("sampled_label", 0); + *ctx->OutputIsDynamic("boxing_disabled_sampled_label", 0) = + ctx->InputIsDynamic("sampled_label", 0); + return Maybe::Ok(); +} +/*static*/ Maybe DistributedPartialFcSampleDisableBoxingOp::InferDataType( + user_op::InferContext* ctx) { + *ctx->OutputDType("boxing_disabled_sampled_weight_diff", 0) = + ctx->InputDType("sampled_weight_diff", 0); + *ctx->OutputDType("boxing_disabled_sampled_label", 0) = ctx->InputDType("sampled_label", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("distributed_partial_fc_sample") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/pool_op.cpp b/oneflow/user/ops/pool_op.cpp index 238c04e4c08..39afc8478b8 100644 --- a/oneflow/user/ops/pool_op.cpp +++ b/oneflow/user/ops/pool_op.cpp @@ -15,6 +15,7 @@ limitations under the License. 
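From here on the pool ops trade registration macros for implementation macros: name##Op / name##GradOp token-pasting stamps the four static methods onto each generated op class (the TfAvgPool and TfMaxPool families here, the maxpool/avgpool families in pooling_op.cpp below), with InferPhysicalTensorDesc delegating to InferLogicalTensorDesc since per-rank shapes need no special handling. A toy sketch of the token-pasting mechanics, with illustrative names rather than OneFlow code:

#include <iostream>

struct MaxPool2DOp { static void Describe(); };
struct AvgPool2DOp { static void Describe(); };

// name##Op splices the class name, so one macro defines the method for
// every variant; #name stringizes it for the message.
#define IMPLEMENT_DESCRIBE(name, dim)                                 \
  void name##Op::Describe() {                                         \
    std::cout << #name " pools over " << (dim) << " spatial dims\n";  \
  }

IMPLEMENT_DESCRIBE(MaxPool2D, 2)
IMPLEMENT_DESCRIBE(AvgPool2D, 2)
#undef IMPLEMENT_DESCRIBE

int main() {
  MaxPool2DOp::Describe();  // prints: MaxPool2D pools over 2 spatial dims
  AvgPool2DOp::Describe();
}
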
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/utils/pool_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -113,49 +114,47 @@ GenBackwardOpConfFn MakeGenBackwardOpConfFn(const std::string& mode, const int32 } // namespace -#define REGISTER_TF_AVG_POOL_FORWARD_OP(name, dim) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Output("y") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(MakeFwTensorDescInferFn(dim)) \ - .SetGetSbpFn(FwGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_TF_AVG_POOL_FORWARD_OP("tf_avg_pool_1d", 1) -REGISTER_TF_AVG_POOL_FORWARD_OP("tf_avg_pool_2d", 2) -REGISTER_TF_AVG_POOL_FORWARD_OP("tf_avg_pool_3d", 3) - -#undef REGISTER_TF_AVG_POOL_FORWARD_OP - -#define REGISTER_TF_AVG_POOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("dy") \ - .Output("dx") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(BwTensorDescInferFn) \ - .SetGetSbpFn(BwGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_TF_AVG_POOL_BACKWARD_OP("tf_avg_pool_1d_grad") -REGISTER_TF_AVG_POOL_BACKWARD_OP("tf_avg_pool_2d_grad") -REGISTER_TF_AVG_POOL_BACKWARD_OP("tf_avg_pool_3d_grad") - -#undef REGISTER_TF_AVG_POOL_FORWARD_OP +#define IMPLEMENT_TF_POOL_FUNCS(name, dim) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { return FwGetSbpFn(ctx); } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return MakeFwTensorDescInferFn(dim)(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return FwInferDataType(ctx); \ + } + +IMPLEMENT_TF_POOL_FUNCS(TfAvgPool1D, 1) +IMPLEMENT_TF_POOL_FUNCS(TfAvgPool2D, 2) +IMPLEMENT_TF_POOL_FUNCS(TfAvgPool3D, 3) +IMPLEMENT_TF_POOL_FUNCS(TfMaxPool1D, 1) +IMPLEMENT_TF_POOL_FUNCS(TfMaxPool2D, 2) +IMPLEMENT_TF_POOL_FUNCS(TfMaxPool3D, 3) +#undef IMPLEMENT_TF_POOL_FUNCS + +#define IMPLEMENT_TF_POOL_BACKWARD_FUNCS(name) \ + /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return BwGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return BwTensorDescInferFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return BwInferDataType(ctx); \ + } + +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool1D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool2D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfAvgPool3D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool1D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool2D) +IMPLEMENT_TF_POOL_BACKWARD_FUNCS(TfMaxPool3D) +#undef IMPLEMENT_TF_POOL_BACKWARD_FUNCS REGISTER_USER_OP_GRAD("tf_avg_pool_1d") .SetGenBackwardOpConfFn(MakeGenBackwardOpConfFn("tf_avg", 1)); @@ -164,50 +163,6 @@ REGISTER_USER_OP_GRAD("tf_avg_pool_2d") REGISTER_USER_OP_GRAD("tf_avg_pool_3d") .SetGenBackwardOpConfFn(MakeGenBackwardOpConfFn("tf_avg", 3)); -#define REGISTER_TF_MAX_POOL_FORWARD_OP(name, dim) \ - REGISTER_USER_OP(name) \ - 
.Input("x") \ - .Output("y") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(MakeFwTensorDescInferFn(dim)) \ - .SetGetSbpFn(FwGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_TF_MAX_POOL_FORWARD_OP("tf_max_pool_1d", 1) -REGISTER_TF_MAX_POOL_FORWARD_OP("tf_max_pool_2d", 2) -REGISTER_TF_MAX_POOL_FORWARD_OP("tf_max_pool_3d", 3) - -#undef REGISTER_TF_MAX_POOL_FORWARD_OP - -#define REGISTER_TF_MAX_POOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("dy") \ - .Output("dx") \ - .Attr("padding") \ - .Attr>("padding_before") \ - .Attr>("padding_after") \ - .Attr("data_format") \ - .Attr>("pool_size") \ - .Attr>("strides") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(BwTensorDescInferFn) \ - .SetGetSbpFn(BwGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_TF_MAX_POOL_BACKWARD_OP("tf_max_pool_1d_grad") -REGISTER_TF_MAX_POOL_BACKWARD_OP("tf_max_pool_2d_grad") -REGISTER_TF_MAX_POOL_BACKWARD_OP("tf_max_pool_3d_grad") - -#undef REGISTER_TF_MAX_POOL_BACKWARD_OP - REGISTER_USER_OP_GRAD("tf_max_pool_1d") .SetGenBackwardOpConfFn(MakeGenBackwardOpConfFn("tf_max", 1)); REGISTER_USER_OP_GRAD("tf_max_pool_2d") diff --git a/oneflow/user/ops/pooling_op.cpp b/oneflow/user/ops/pooling_op.cpp index 9b44fb13a74..c46b85188d3 100644 --- a/oneflow/user/ops/pooling_op.cpp +++ b/oneflow/user/ops/pooling_op.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/pooling_kernel_util.h" #include "oneflow/user/kernels/avg_pooling_kernel_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -204,99 +205,85 @@ Maybe BwInferDataType(user_op::InferContext* ctx) { } } // namespace -#define REGISTER_MAXPOOL_FORWARD_OP(name, dim) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Output("y") \ - .Output("indice") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr>("dilation") \ - .Attr("return_indices") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(MaxPoolMakeForwardTensorDescInferFn(dim)) \ - .SetGetSbpFn(MaxPoolForwardGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_MAXPOOL_FORWARD_OP("maxpool_1d", 1) -REGISTER_MAXPOOL_FORWARD_OP("maxpool_2d", 2) -REGISTER_MAXPOOL_FORWARD_OP("maxpool_3d", 3) - -#undef REGISTER_MAXPOOL_FORWARD_OP - -#define REGISTER_MAXPOOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("indice") \ - .Input("dy") \ - .Output("dx") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr>("dilation") \ - .Attr("return_indices") \ - .Attr("ceil_mode") \ - .SetTensorDescInferFn(BackwardTensorDescInferFn) \ - .SetGetSbpFn(MaxPoolBackwardGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_MAXPOOL_BACKWARD_OP("maxpool_1d_grad") -REGISTER_MAXPOOL_BACKWARD_OP("maxpool_2d_grad") -REGISTER_MAXPOOL_BACKWARD_OP("maxpool_3d_grad") - -#undef REGISTER_MAXPOOL_BACKWARD_OP +#define IMPLEMENT_MAXPOOL_FUNCS(name, dim) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return MaxPoolForwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return MaxPoolMakeForwardTensorDescInferFn(dim)(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* 
ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return FwInferDataType(ctx); \ + } + +IMPLEMENT_MAXPOOL_FUNCS(MaxPool1D, 1) +IMPLEMENT_MAXPOOL_FUNCS(MaxPool2D, 2) +IMPLEMENT_MAXPOOL_FUNCS(MaxPool3D, 3) +#undef IMPLEMENT_MAXPOOL_FUNCS + +#define IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(name) \ + /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return MaxPoolBackwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return BackwardTensorDescInferFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return BwInferDataType(ctx); \ + } + +IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool1D) +IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool2D) +IMPLEMENT_MAXPOOL_BACKWARD_FUNCS(MaxPool3D) +#undef IMPLEMENT_MAXPOOL_BACKWARD_FUNCS REGISTER_USER_OP_GRAD("maxpool_1d").SetGenBackwardOpConfFn(MaxPoolMakeBackwardOpConfFn("max", 1)); REGISTER_USER_OP_GRAD("maxpool_2d").SetGenBackwardOpConfFn(MaxPoolMakeBackwardOpConfFn("max", 2)); REGISTER_USER_OP_GRAD("maxpool_3d").SetGenBackwardOpConfFn(MaxPoolMakeBackwardOpConfFn("max", 3)); -#define REGISTER_AVGPOOL_FORWARD_OP(name, ndim) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Output("y") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr("ceil_mode") \ - .Attr("count_include_pad") \ - .Attr("divisor_override") \ - .SetTensorDescInferFn(AvgPoolMakeForwardTensorDescInferFn(ndim)) \ - .SetGetSbpFn(AvgPoolForwardGetSbpFn) \ - .SetDataTypeInferFn(FwInferDataType); - -REGISTER_AVGPOOL_FORWARD_OP("avgpool_1d", 1); -REGISTER_AVGPOOL_FORWARD_OP("avgpool_2d", 2); -REGISTER_AVGPOOL_FORWARD_OP("avgpool_3d", 3); - -#undef REGISTER_AVGPOOL_FORWARD_OP - -#define REGISTER_AVGPOOL_BACKWARD_OP(name) \ - REGISTER_USER_OP(name) \ - .Input("x") \ - .Input("y") \ - .Input("dy") \ - .Output("dx") \ - .Attr>("padding") \ - .Attr("data_format") \ - .Attr>("kernel_size") \ - .Attr>("stride") \ - .Attr("ceil_mode") \ - .Attr("count_include_pad") \ - .Attr("divisor_override") \ - .SetTensorDescInferFn(BackwardTensorDescInferFn) \ - .SetGetSbpFn(AvgPoolBackwardGetSbpFn) \ - .SetDataTypeInferFn(BwInferDataType); - -REGISTER_AVGPOOL_BACKWARD_OP("avgpool_1d_grad"); -REGISTER_AVGPOOL_BACKWARD_OP("avgpool_2d_grad"); -REGISTER_AVGPOOL_BACKWARD_OP("avgpool_3d_grad"); - -#undef REGISTER_AVGPOOL_BACKWARD_OP +#define IMPLEMENT_AVGPOOL_FUNCS(name, ndim) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return AvgPoolForwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return AvgPoolMakeForwardTensorDescInferFn(ndim)(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return FwInferDataType(ctx); \ + } + +IMPLEMENT_AVGPOOL_FUNCS(AvgPool1D, 1) +IMPLEMENT_AVGPOOL_FUNCS(AvgPool2D, 2) +IMPLEMENT_AVGPOOL_FUNCS(AvgPool3D, 3) +#undef IMPLEMENT_AVGPOOL_FUNCS + +#define IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(name) \ + /*static*/ Maybe name##GradOp::GetSbp(user_op::SbpContext* ctx) { \ + return AvgPoolBackwardGetSbpFn(ctx); \ + } \ + /*static*/ Maybe 
name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return BackwardTensorDescInferFn(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##GradOp::InferDataType(user_op::InferContext* ctx) { \ + return BwInferDataType(ctx); \ + } + +IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool1D) +IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool2D) +IMPLEMENT_AVGPOOL_BACKWARD_FUNCS(AvgPool3D) +#undef IMPLEMENT_AVGPOOL_BACKWARD_FUNCS REGISTER_USER_OP_GRAD("avgpool_1d").SetGenBackwardOpConfFn(AvgPoolMakeBackwardOpConfFn(1)); REGISTER_USER_OP_GRAD("avgpool_2d").SetGenBackwardOpConfFn(AvgPoolMakeBackwardOpConfFn(2)); diff --git a/oneflow/user/ops/prelu_op.cpp b/oneflow/user/ops/prelu_op.cpp index c100cbe1b59..c6104318156 100644 --- a/oneflow/user/ops/prelu_op.cpp +++ b/oneflow/user/ops/prelu_op.cpp @@ -14,105 +14,101 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("prelu") - .Input("x") - .Input("alpha") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - const Shape& alpha_shape = ctx->InputShape("alpha", 0); - CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); - *y_shape = x_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - const user_op::TensorDesc& alpha_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); - if (alpha_tensor.shape().At(0) != 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 1) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("y", 0), 1) - .Build(); - } - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i == 1) continue; - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Broadcast(user_op::OpArg("alpha", 0)) - .Split(user_op::OpArg("y", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PreluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0); + if (alpha_tensor.shape().At(0) != 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 1) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("y", 0), 1) + .Build(); + } + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i == 1) continue; + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Broadcast(user_op::OpArg("alpha", 0)) + .Split(user_op::OpArg("y", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe PreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + const Shape& alpha_shape = ctx->InputShape("alpha", 0); + CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); + *y_shape = x_shape; + return Maybe::Ok(); +} +/*static*/ Maybe PreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ 
Maybe PreluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("prelu_grad") - .Input("dy") - .Input("x") - .Input("alpha") - .Output("dx") - .Output("alpha_diff") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - Shape* alpha_diff_shape = ctx->OutputShape("alpha_diff", 0); - const Shape& alpha_shape = ctx->InputShape("alpha", 0); - CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); - CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); - CHECK_EQ_OR_RETURN(dy_shape, x_shape); - *dx_shape = x_shape; - *alpha_diff_shape = alpha_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x", 0), 0) - .Broadcast(user_op::OpArg("alpha", 0)) - .Split(user_op::OpArg("dx", 0), 0) - .PartialSum(user_op::OpArg("alpha_diff", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("alpha", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .PartialSum(user_op::OpArg("alpha_diff", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 1) - .Split(user_op::OpArg("x", 0), 1) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("dx", 0), 1) - .Split(user_op::OpArg("alpha_diff", 0), 0) - .Build(); - FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("alpha", 0), 0) - .Split(user_op::OpArg("dx", 0), i) - .Split(user_op::OpArg("alpha_diff", 0), 0) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe PreluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x", 0), 0) + .Broadcast(user_op::OpArg("alpha", 0)) + .Split(user_op::OpArg("dx", 0), 0) + .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("dy", 0)) + .Broadcast(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("alpha", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .PartialSum(user_op::OpArg("alpha_diff", 0)) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 1) + .Split(user_op::OpArg("x", 0), 1) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("dx", 0), 1) + .Split(user_op::OpArg("alpha_diff", 0), 0) + .Build(); + FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("alpha", 0), 0) + .Split(user_op::OpArg("dx", 0), i) + .Split(user_op::OpArg("alpha_diff", 0), 0) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe PreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = 
ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + Shape* alpha_diff_shape = ctx->OutputShape("alpha_diff", 0); + const Shape& alpha_shape = ctx->InputShape("alpha", 0); + CHECK_EQ_OR_RETURN(alpha_shape.NumAxes(), 1); + CHECK_OR_RETURN((alpha_shape.At(0) == x_shape.At(1)) || (alpha_shape.At(0) == 1)); + CHECK_EQ_OR_RETURN(dy_shape, x_shape); + *dx_shape = x_shape; + *alpha_diff_shape = alpha_shape; + return Maybe::Ok(); +} +/*static*/ Maybe PreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe PreluGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("prelu").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/quantization_op.cpp b/oneflow/user/ops/quantization_op.cpp index b67e707113e..2396a1a1685 100644 --- a/oneflow/user/ops/quantization_op.cpp +++ b/oneflow/user/ops/quantization_op.cpp @@ -14,105 +14,93 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe QuantizationOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const Shape& logical_scale_shape = + ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("in", 0)) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + if (logical_scale_shape.elem_cnt() > 1) { + // NOTE(Liang Depeng): only consider convolution weight per-channel quantization + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), 0) + .Split(user_op::OpArg("scale", 0), 0) + .Split(user_op::OpArg("zero_point", 0), 0) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + } else { + // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise + // ops + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), 0) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Split(user_op::OpArg("out", 0), 0) + .Build(); + } + FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Broadcast(user_op::OpArg("scale", 0)) + .Broadcast(user_op::OpArg("zero_point", 0)) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + const Shape& scale_shape = ctx->InputShape("scale", 0); + const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); -REGISTER_USER_OP("quantization") - .Input("in") - .Input("scale") - .Input("zero_point") - .Output("out") - // NOTE(Liang Depeng): "google" or "cambricon" - .Attr("quantization_formula", "google") - // NOTE(Liang Depeng): quantize from float32 to "quantization_bit" bit signed or unsigned - // integer - .Attr("quantization_bit", 8) - // NOTE(Liang Depeng): "symmetric" or "affine": quantize to signed or unsigned integer - .Attr("quantization_scheme", "symmetric") - .SetTensorDescInferFn([](user_op::InferContext* ctx) 
-> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - const Shape& scale_shape = ctx->InputShape("scale", 0); - const Shape& zero_point_shape = ctx->InputShape("zero_point", 0); + // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for + // convolution weights. + if (scale_shape.elem_cnt() > 1) { + CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); + CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); + } - // NOTE(Liang Depeng): scale_shape->elem_cnt() > 1 means per-channel quantization for - // convolution weights. - if (scale_shape.elem_cnt() > 1) { - CHECK_EQ_OR_RETURN(scale_shape.elem_cnt(), in_shape.At(0)); - CHECK_EQ_OR_RETURN(zero_point_shape.elem_cnt(), in_shape.At(0)); - } + *ctx->OutputShape("out", 0) = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe QuantizationOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); + CHECK_OR_RETURN(scale != nullptr); + scale->set_requires_grad(false); - *ctx->OutputShape("out", 0) = in_shape; - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* scale = GetInputArgModifierFn("scale", 0); - CHECK_OR_RETURN(scale != nullptr); - scale->set_requires_grad(false); + user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); + CHECK_OR_RETURN(zero_point != nullptr); + zero_point->set_requires_grad(false); + return Maybe::Ok(); +} +/*static*/ Maybe QuantizationOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + const int32_t quantization_bit = op_conf.attr("quantization_bit"); + CHECK_GT_OR_RETURN(quantization_bit, 1); + CHECK_LE_OR_RETURN(quantization_bit, 8); - user_op::InputArgModifier* zero_point = GetInputArgModifierFn("zero_point", 0); - CHECK_OR_RETURN(zero_point != nullptr); - zero_point->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const Shape& logical_scale_shape = - ctx->LogicalTensorDesc4InputArgNameAndIndex("scale", 0).shape(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - if (logical_scale_shape.elem_cnt() > 1) { - // NOTE(Liang Depeng): only consider convolution weight per-channel quantization - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Split(user_op::OpArg("scale", 0), 0) - .Split(user_op::OpArg("zero_point", 0), 0) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - } else { - // NOTE(Liang Depeng): the sbp signature of per-layer quantization is the same as eltwise - // ops - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), 0) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Split(user_op::OpArg("out", 0), 0) - .Build(); - } - FOR_RANGE(int64_t, i, 1, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - 
.Split(user_op::OpArg("in", 0), i) - .Broadcast(user_op::OpArg("scale", 0)) - .Broadcast(user_op::OpArg("zero_point", 0)) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def, - const user_op::UserOpConfWrapper& op_conf) -> Maybe { - const int32_t quantization_bit = op_conf.attr("quantization_bit"); - CHECK_GT_OR_RETURN(quantization_bit, 1); - CHECK_LE_OR_RETURN(quantization_bit, 8); + std::string quantization_scheme = op_conf.attr("quantization_scheme"); + CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); - std::string quantization_scheme = op_conf.attr("quantization_scheme"); - CHECK_OR_RETURN(quantization_scheme == "symmetric" || quantization_scheme == "affine"); - - std::string quantization_formula = op_conf.attr("quantization_formula"); - CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); - -} // namespace + std::string quantization_formula = op_conf.attr("quantization_formula"); + CHECK_OR_RETURN(quantization_formula == "google" || quantization_formula == "cambricon"); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp index 7ce336d4b44..ae52c9b2938 100644 --- a/oneflow/user/ops/randperm_op.cpp +++ b/oneflow/user/ops/randperm_op.cpp @@ -14,34 +14,29 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/core/common/global.h" -#include "oneflow/core/common/multi_client.h" -#include "oneflow/core/common/protobuf.h" -#include "oneflow/core/job/global_for.h" +#include "oneflow/core/framework/op_generated.h" + namespace oneflow { -Maybe InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx); -REGISTER_NO_GRAD_USER_OP("randperm") - .Output("out") - .Attr("n") - .Attr("seed") - .Attr>("nd_sbp") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - int32_t n = ctx->Attr("n"); - CHECK_GE_OR_RETURN(n, 0); - *out_shape = Shape({n}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kInt32; - return Maybe::Ok(); - }) - .SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe { - cfg::SbpParallel default_sbp; - default_sbp.mutable_broadcast_parallel(); - return user_op::InferNdSbp4SrcOp(ctx, default_sbp); - }); +/*static*/ Maybe RandpermOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::SbpParallel default_sbp; + default_sbp.mutable_broadcast_parallel(); + return user_op::InferNdSbp4SrcOp(ctx, default_sbp); +} +/*static*/ Maybe RandpermOp::GetSbp(user_op::SbpContext* ctx) { return Maybe::Ok(); } +/*static*/ Maybe RandpermOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + int32_t n = ctx->Attr("n"); + CHECK_GE_OR_RETURN(n, 0); + *out_shape = Shape({n}); + return Maybe::Ok(); +} +/*static*/ Maybe RandpermOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RandpermOp::InferDataType(user_op::InferContext* ctx) { + 
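+  // randperm emits a permutation of [0, n), so the element type is pinned to
+  // int32 rather than inferred: a source op has no tensor inputs to copy a
+  // dtype from.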
*ctx->OutputDType("out", 0) = DataType::kInt32; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/reduce_like_ops.cpp b/oneflow/user/ops/reduce_like_ops.cpp index c8fc889debc..898e81ade46 100644 --- a/oneflow/user/ops/reduce_like_ops.cpp +++ b/oneflow/user/ops/reduce_like_ops.cpp @@ -15,81 +15,80 @@ limitations under the License. */ #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("reduce_sum_like") - .Input("x") - .Input("like") - .Output("y") - .Attr>("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); - const auto& axis = ctx->Attr>("axis"); - if (axis.empty()) { CHECK_EQ_OR_RETURN(x_tensor.shape(), like_tensor.shape()); } - user_op::TensorDesc* y_tensor = ctx->OutputTensorDesc("y", 0); - *y_tensor->mut_shape() = like_tensor.shape(); - *y_tensor->mut_is_dynamic() = like_tensor.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int32_t num_axes = 0; - HashSet conf_axes; - { - const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - num_axes = in_tensor.shape().NumAxes(); - const auto& reduced_axes = ctx->Attr>("axis"); - ReduceSbpUtil::GetRegularAxes(num_axes, reduced_axes, &conf_axes); - } - const auto& like_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); - const bool keep_dims = (num_axes == like_num_axes); - if (!keep_dims) { CHECK_EQ_OR_RETURN(conf_axes.size(), num_axes - like_num_axes); } - auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); - int64_t num_reduced_axes = 0; - FOR_RANGE(int64_t, i, 0, num_axes) { - if (IsReducedAxis(i)) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - num_reduced_axes += 1; - } else { - const int64_t out_split_axis = keep_dims ? 
i : i - num_reduced_axes; - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("like", 0), out_split_axis) - .Split(user_op::OpArg("y", 0), out_split_axis) - .Build(); - } - } +/*static*/ Maybe ReduceSumLikeOp::GetSbp(user_op::SbpContext* ctx) { + int32_t num_axes = 0; + HashSet conf_axes; + { + const auto& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + num_axes = in_tensor.shape().NumAxes(); + const auto& reduced_axes = ctx->Attr>("axis"); + ReduceSbpUtil::GetRegularAxes(num_axes, reduced_axes, &conf_axes); + } + const auto& like_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); + const bool keep_dims = (num_axes == like_num_axes); + if (!keep_dims) { CHECK_EQ_OR_RETURN(conf_axes.size(), num_axes - like_num_axes); } + auto IsReducedAxis = ReduceSbpUtil::MakePredicatorIsReducedAxis(conf_axes, num_axes); + int64_t num_reduced_axes = 0; + FOR_RANGE(int64_t, i, 0, num_axes) { + if (IsReducedAxis(i)) { ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) + .Split(user_op::OpArg("x", 0), i) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("y", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + num_reduced_axes += 1; + } else { + const int64_t out_split_axis = keep_dims ? i : i - num_reduced_axes; + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("like", 0), out_split_axis) + .Split(user_op::OpArg("y", 0), out_split_axis) .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); - const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); - CHECK_EQ_OR_RETURN(x_tensor.data_type(), like_tensor.data_type()); - *ctx->OutputDType("y", 0) = like_tensor.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); - CHECK_OR_RETURN(like_arg_modifier != nullptr); - like_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }); + } + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("y", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe ReduceSumLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); + const auto& axis = ctx->Attr>("axis"); + if (axis.empty()) { CHECK_EQ_OR_RETURN(x_tensor.shape(), like_tensor.shape()); } + user_op::TensorDesc* y_tensor = ctx->OutputTensorDesc("y", 0); + *y_tensor->mut_shape() = like_tensor.shape(); + *y_tensor->mut_is_dynamic() = like_tensor.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe ReduceSumLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReduceSumLikeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->InputTensorDesc("x", 0); + const user_op::TensorDesc& like_tensor = ctx->InputTensorDesc("like", 0); + CHECK_EQ_OR_RETURN(x_tensor.data_type(), like_tensor.data_type()); + *ctx->OutputDType("y", 0) = 
like_tensor.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe ReduceSumLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* like_arg_modifier = GetInputArgModifierFn("like", 0); + CHECK_OR_RETURN(like_arg_modifier != nullptr); + like_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/reduce_ops.cpp b/oneflow/user/ops/reduce_ops.cpp index 68abea62e7d..610c01a5aaf 100644 --- a/oneflow/user/ops/reduce_ops.cpp +++ b/oneflow/user/ops/reduce_ops.cpp @@ -16,6 +16,7 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/core/operator/reduce_sbp_util.h" #include "oneflow/core/ndarray/binary_func.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -83,32 +84,27 @@ Maybe GetSbpFn(user_op::SbpContext* ctx) { return Maybe::Ok(); } -#define REGISTER_REDUCE_USER_OP(op_name, binary_func) \ - REGISTER_USER_OP(op_name) \ - .Input("input_tensor") \ - .Output("output_tensor") \ - .Attr>("axis") \ - .Attr("keepdims") \ - .SetTensorDescInferFn(InferTensorDescFn) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferDataType); - -#define REGISTER_REDUCE_LOGICAL_USER_OP(op_name, binary_func) \ - REGISTER_USER_OP(op_name) \ - .Input("input_tensor") \ - .Output("output_tensor") \ - .Attr>("axis") \ - .Attr("keepdims") \ - .SetTensorDescInferFn(InferTensorDescFn) \ - .SetGetSbpFn(GetSbpFn) \ - .SetDataTypeInferFn(InferLogicalDataType); - -REGISTER_REDUCE_LOGICAL_USER_OP("reduce_any", BinaryFuncAny) -REGISTER_REDUCE_LOGICAL_USER_OP("reduce_all", BinaryFuncAll) -REGISTER_REDUCE_USER_OP("reduce_min", BinaryFuncMin) -REGISTER_REDUCE_USER_OP("reduce_prod", BinaryFuncProd) -REGISTER_REDUCE_USER_OP("reduce_sum", BinaryFuncSum) -REGISTER_REDUCE_USER_OP("reduce_max", BinaryFuncMax) +#define IMPLEMENT_REDUCE_OP_FUNCS(name, binary_func, infer_dtype_func) \ + /*static*/ Maybe name##Op::GetSbp(user_op::SbpContext* ctx) { \ + return GetSbpFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + return InferTensorDescFn(ctx); \ + } \ + /*static*/ Maybe name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe name##Op::InferDataType(user_op::InferContext* ctx) { \ + return infer_dtype_func(ctx); \ + } + +IMPLEMENT_REDUCE_OP_FUNCS(ReduceAny, BinaryFuncAny, InferLogicalDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceAll, BinaryFuncAll, InferLogicalDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceMin, BinaryFuncMin, oneflow::InferDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceMax, BinaryFuncMax, oneflow::InferDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceSum, BinaryFuncSum, oneflow::InferDataType) +IMPLEMENT_REDUCE_OP_FUNCS(ReduceProd, BinaryFuncProd, oneflow::InferDataType) +#undef IMPLEMENT_REDUCE_OP_FUNCS REGISTER_USER_OP_GRAD("reduce_sum") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/relu_op.cpp b/oneflow/user/ops/relu_op.cpp index d2ae5b6bf23..52fb55fdc22 100644 --- a/oneflow/user/ops/relu_op.cpp +++ b/oneflow/user/ops/relu_op.cpp @@ -14,76 +14,73 @@ See the License for the specific language governing permissions and limitations under the License. 
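Besides the mechanical conversion, relu's argument names change from in/out to x/y, and the backward registration below is updated to bind the renamed blobs. relu_grad deliberately consumes the saved output y instead of x: where y > 0 the gradient passes through, elsewhere it is zero, which is also why InferLogicalTensorDesc insists dy and y share a shape. A plain-C++ sketch of that contract over flat buffers (not the OneFlow kernel):

#include <cstddef>
#include <vector>

std::vector<float> Relu(const std::vector<float>& x) {
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) y[i] = x[i] > 0.f ? x[i] : 0.f;
  return y;
}

std::vector<float> ReluGrad(const std::vector<float>& y,
                            const std::vector<float>& dy) {
  // dy and y must match, mirroring CHECK_OR_RETURN(dy_shape == y_shape).
  std::vector<float> dx(y.size());
  for (std::size_t i = 0; i < y.size(); ++i) dx[i] = y[i] > 0.f ? dy[i] : 0.f;
  return dx;
}
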
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe ReluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe ReluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("x", 0); + Shape* out_shape = ctx->OutputShape("y", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe ReluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("relu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe ReluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe ReluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe ReluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReluGradOp::InferDataType(user_op::InferContext* ctx) { + const DataType& data_type = ctx->InputDType("y", 0); + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type); + *ctx->OutputDType("dx", 0) = data_type; + return Maybe::Ok(); +} -REGISTER_USER_OP("relu_grad") - .Input("y") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - 
.Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& data_type = ctx->InputDType("y", 0); - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), data_type); - *ctx->OutputDType("dx", 0) = data_type; - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("relu").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe { const auto relu_grad_op_name = ctx->FwOp().op_name() + "_grad"; ctx->DefineOp(relu_grad_op_name, [&ctx](user_op::BackwardOpBuilder& builder) { return builder.OpTypeName("relu_grad") - .InputBind("y", ctx->FwOp().output("out", 0)) - .InputBind("dy", ctx->FwOp().output_grad("out", 0)) + .InputBind("y", ctx->FwOp().output("y", 0)) + .InputBind("dy", ctx->FwOp().output_grad("y", 0)) .Output("dx") .Build(); }); - ctx->FwOp().InputGradBind(user_op::OpArg("in", 0), + ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), [&ctx, &relu_grad_op_name]() -> const std::string& { return ctx->GetOp(relu_grad_op_name).output("dx", 0); }); diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp index 4098ac9d00e..2b087308603 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -14,45 +14,42 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe RepeatOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe RepeatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe RepeatOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe RepeatOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec()); + dim_vec.emplace_back(ctx->user_op_conf().attr("repeat_num")); + *ctx->mut_output_blob_time_shape() = Shape(dim_vec); + return Maybe::Ok(); +} -REGISTER_USER_OP("repeat") - .Input("in") - .Output("out") - .Attr("repeat_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - 
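// [editor's note -- not part of the diff] Alongside the schema migration,
// "relu" is renamed from in/out to x/y, so the backward rule above rebinds
// against the new names:
//
//   .InputBind("y",  ctx->FwOp().output("y", 0))       // was output("out", 0)
//   .InputBind("dy", ctx->FwOp().output_grad("y", 0))  // was output_grad("out", 0)
//   ctx->FwOp().InputGradBind(user_op::OpArg("x", 0), ...)  // was OpArg("in", 0)
//
// Any caller that still builds a relu op conf with "in"/"out" would break;
// presumably the functional/Python layers are migrated in the same PR.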
diff --git a/oneflow/user/ops/repeat_op.cpp b/oneflow/user/ops/repeat_op.cpp index 4098ac9d00e..2b087308603 100644 --- a/oneflow/user/ops/repeat_op.cpp +++ b/oneflow/user/ops/repeat_op.cpp @@ -14,45 +14,42 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe<void> RepeatOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RepeatOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RepeatOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> RepeatOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RepeatOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec()); + dim_vec.emplace_back(ctx->user_op_conf().attr<int32_t>("repeat_num")); + *ctx->mut_output_blob_time_shape() = Shape(dim_vec); + return Maybe<void>::Ok(); +} -REGISTER_USER_OP("repeat") - .Input("in") - .Output("out") - .Attr<int32_t>("repeat_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe<void>::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe<void> { - DimVector dim_vec(ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec()); - dim_vec.emplace_back(ctx->user_op_conf().attr<int32_t>("repeat_num")); - *ctx->mut_output_blob_time_shape() = Shape(dim_vec); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe<void>::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("repeat").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
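// [editor's note -- not part of the diff] InferOutputBlobTimeShape above is
// the one hook here that isn't shape/dtype/SBP inference: repeating a blob
// repeat_num times appends that factor to the time shape. A standalone sketch
// of the arithmetic (hypothetical helper, plain vectors instead of DimVector):
#include <cstdint>
#include <vector>
std::vector<int64_t> RepeatTimeShape(std::vector<int64_t> in_time_shape, int32_t repeat_num) {
  in_time_shape.push_back(repeat_num);  // e.g. {4} with repeat_num = 3 -> {4, 3}
  return in_time_shape;
}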
.PartialSum(user_op::OpArg("out", 0)) + .Build(); + user_op::UserOpSbpSignatureBuilder builder = ctx->NewBuilder(); + return ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures( + in_shape, like_shape, {{"in", 0}}, {{"like", 0}, {"out", 0}}, ctx->parallel_num(), &builder); +} +/*static*/ Maybe ReshapeLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + const Shape& like_shape = ctx->InputShape("like", 0); + CHECK_EQ_OR_RETURN(in_shape.elem_cnt(), like_shape.elem_cnt()); + *ctx->OutputShape("out", 0) = like_shape; + return Maybe::Ok(); +} +/*static*/ Maybe ReshapeLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe ReshapeLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} +/*static*/ Maybe ReshapeLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_NOTNULL_OR_RETURN(like_modifier); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("reshape_like") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 2be0133bdda..4a69084129e 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -17,12 +17,11 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/reshape_user_op_util.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe GetSbpFn(user_op::SbpContext* ctx) { +/*static*/ Maybe ReshapeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->Attr("shape"); const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); @@ -31,14 +30,14 @@ Maybe GetSbpFn(user_op::SbpContext* ctx) { in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->parallel_num(), &builder); } -Maybe InferNdSbpFn(user_op::InferNdSbpFnContext* ctx) { +/*static*/ Maybe ReshapeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const Shape& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->user_op_conf().attr("shape"); const auto& out_shape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); return ReshapeUserOpUtil::InferNdSbp(ctx, in_shape, *out_shape); } -Maybe LogicalTensorDescInferFn(user_op::InferContext* ctx) { +/*static*/ Maybe ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { Shape shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); @@ -70,7 +69,7 @@ Maybe LogicalTensorDescInferFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe TensorDescInferFn(user_op::InferContext* ctx) { +/*static*/ Maybe ReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); @@ -115,20 +114,12 @@ Maybe TensorDescInferFn(user_op::InferContext* ctx) { return Maybe::Ok(); } -Maybe 
diff --git a/oneflow/user/ops/reshape_op.cpp b/oneflow/user/ops/reshape_op.cpp index 2be0133bdda..4a69084129e 100644 --- a/oneflow/user/ops/reshape_op.cpp +++ b/oneflow/user/ops/reshape_op.cpp @@ -17,12 +17,11 @@ limitations under the License. #include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/reshape_user_op_util.h" #include "oneflow/core/operator/operator.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe<void> GetSbpFn(user_op::SbpContext* ctx) { +/*static*/ Maybe<void> ReshapeOp::GetSbp(user_op::SbpContext* ctx) { const auto& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->Attr<Shape>("shape"); const auto& outshape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); @@ -31,14 +30,14 @@ Maybe<void> GetSbpFn(user_op::SbpContext* ctx) { in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->parallel_num(), &builder); } -Maybe<void> InferNdSbpFn(user_op::InferNdSbpFnContext* ctx) { +/*static*/ Maybe<void> ReshapeOp::InferNdSbp(user_op::InferNdSbpFnContext* ctx) { const Shape& in_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape(); const Shape& shape = ctx->user_op_conf().attr<Shape>("shape"); const auto& out_shape = JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape)); return ReshapeUserOpUtil::InferNdSbp(ctx, in_shape, *out_shape); } -Maybe<void> LogicalTensorDescInferFn(user_op::InferContext* ctx) { +/*static*/ Maybe<void> ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { Shape shape = ctx->Attr<Shape>("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); @@ -70,7 +69,7 @@ Maybe<void> LogicalTensorDescInferFn(user_op::InferContext* ctx) { return Maybe<void>::Ok(); } -Maybe<void> TensorDescInferFn(user_op::InferContext* ctx) { +/*static*/ Maybe<void> ReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { Shape logical_shape = ctx->Attr<Shape>("shape"); const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0); user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("out", 0); @@ -115,20 +114,12 @@ Maybe<void> TensorDescInferFn(user_op::InferContext* ctx) { return Maybe<void>::Ok(); } -Maybe<void> InferDataType(user_op::InferContext* ctx) { +/*static*/ Maybe<void> ReshapeOp::InferDataType(user_op::InferContext* ctx) { *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); return Maybe<void>::Ok(); } -REGISTER_USER_OP("reshape") - .Input("in") - .Output("out") - .Attr<Shape>("shape") - .SetLogicalTensorDescInferFn(LogicalTensorDescInferFn) - .SetPhysicalTensorDescInferFn(TensorDescInferFn) - .SetGetSbpFn(GetSbpFn) - .SetNdSbpInferFn(InferNdSbpFn) - .SetDataTypeInferFn(InferDataType); +namespace { REGISTER_USER_OP_GRAD("reshape").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe<void> {
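// [editor's note -- not part of the diff] Reshape is the rare op in this
// diff whose logical and physical tensor-desc inference genuinely differ:
// the per-rank physical shape depends on how the logical tensor is split.
// A standalone sketch of the balanced-split arithmetic involved (hypothetical
// helper; the real logic lives in ReshapeUserOpUtil and BalancedSplitter):
#include <cstdint>
int64_t SplitDimOnRank(int64_t logical_dim, int64_t parallel_num, int64_t rank) {
  const int64_t base = logical_dim / parallel_num;
  const int64_t remainder = logical_dim % parallel_num;
  // the first `remainder` ranks each get one extra row of the split axis
  return base + (rank < remainder ? 1 : 0);
}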
diff --git a/oneflow/user/ops/roi_align_op.cpp b/oneflow/user/ops/roi_align_op.cpp index 58ff83e6419..15568d7a672 100644 --- a/oneflow/user/ops/roi_align_op.cpp +++ b/oneflow/user/ops/roi_align_op.cpp @@ -14,12 +14,19 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -Maybe<void> InferRoiAlignTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe<void> RoiAlignOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), 0) + .Split(user_op::OpArg("rois", 0), 0) + .Split(user_op::OpArg("y", 0), 0) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RoiAlignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& x_shape = ctx->InputShape("x", 0); const Shape& rois_shape = ctx->InputShape("rois", 0); const int32_t pooled_h = ctx->Attr<int32_t>("pooled_h"); @@ -33,8 +40,34 @@ Maybe<void> InferRoiAlignTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("y", 0) = Shape({rois_shape.At(0), x_shape.At(1), pooled_h, pooled_w}); return Maybe<void>::Ok(); } +/*static*/ Maybe<void> RoiAlignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> RoiAlignOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RoiAlignOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* roi_modifier = GetInputArgModifierFn("rois", 0); + CHECK(roi_modifier != nullptr); + roi_modifier->set_requires_grad(false); + user_op::InputArgModifier* feat_modifier = GetInputArgModifierFn("x", 0); + CHECK(feat_modifier != nullptr); + feat_modifier->set_requires_grad(true); + return Maybe<void>::Ok(); +} -Maybe<void> InferRoiAlignGradTensorDesc(user_op::InferContext* ctx) { +/*static*/ Maybe<void> RoiAlignGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("x_like", 0), 0) + .Split(user_op::OpArg("rois", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RoiAlignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { const Shape& dy_shape = ctx->InputShape("dy", 0); const Shape& x_like_shape = ctx->InputShape("x_like", 0); const Shape& rois_shape = ctx->InputShape("rois", 0); @@ -51,47 +84,16 @@ Maybe<void> InferRoiAlignGradTensorDesc(user_op::InferContext* ctx) { *ctx->OutputShape("dx", 0) = x_like_shape; return Maybe<void>::Ok(); } - -Maybe<void> InferRoiAlignDataType(user_op::InferContext* ctx) { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe<void>::Ok(); +/*static*/ Maybe<void> RoiAlignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); } - -Maybe<void> InferRoiAlignGradDataType(user_op::InferContext* ctx) { +/*static*/ Maybe<void> RoiAlignGradOp::InferDataType(user_op::InferContext* ctx) { CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x_like", 0)); *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); return Maybe<void>::Ok(); } -Maybe<void> RoiAlignSbpFn(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), 0) - .Split(user_op::OpArg("rois", 0), 0) - .Split(user_op::OpArg("y", 0), 0) - .Build(); - return Maybe<void>::Ok(); -} - -Maybe<void> RoiAlignGradSbpFn(user_op::SbpContext* ctx) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("x_like", 0), 0) - .Split(user_op::OpArg("rois", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - return Maybe<void>::Ok(); -} - -Maybe<void> RoiAlignArgModifier(const user_op::GetInputArgModifier& GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) { - user_op::InputArgModifier* roi_modifier = GetInputArgModifierFn("rois", 0); - CHECK(roi_modifier != nullptr); - roi_modifier->set_requires_grad(false); - user_op::InputArgModifier* feat_modifier = GetInputArgModifierFn("x", 0); - CHECK(feat_modifier != nullptr); - feat_modifier->set_requires_grad(true); - return Maybe<void>::Ok(); -} +namespace { Maybe<void> GenerateBackwardOpConf4RoiAlign(const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) { @@ -117,34 +119,6 @@ Maybe<void> GenerateBackwardOpConf4RoiAlign(const user_op::UserOpWrapper& op, } // namespace -REGISTER_USER_OP("roi_align") - .Input("x") - .Input("rois") - .Output("y") - .Attr<int32_t>("pooled_h") - .Attr<int32_t>("pooled_w") - .Attr<float>("spatial_scale") - .Attr<int32_t>("sampling_ratio") - .Attr<bool>("aligned") - .SetTensorDescInferFn(InferRoiAlignTensorDesc) - .SetDataTypeInferFn(InferRoiAlignDataType) - .SetGetSbpFn(RoiAlignSbpFn) - .SetInputArgModifyFn(RoiAlignArgModifier); - -REGISTER_USER_OP("roi_align_grad") - .Input("dy") - .Input("x_like") - .Input("rois") - .Output("dx") - .Attr<int32_t>("pooled_h") - .Attr<int32_t>("pooled_w") - .Attr<float>("spatial_scale") - .Attr<int32_t>("sampling_ratio") - .Attr<bool>("aligned") - .SetTensorDescInferFn(InferRoiAlignGradTensorDesc) - .SetDataTypeInferFn(InferRoiAlignGradDataType) - .SetGetSbpFn(RoiAlignGradSbpFn); - REGISTER_USER_OP_GRAD("roi_align").SetGenBackwardOpConfFn(GenerateBackwardOpConf4RoiAlign); } // namespace oneflow
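// [editor's note -- not part of the diff] RoiAlignOp::InferLogicalTensorDesc
// above produces one pooled window per ROI, with channels carried over from
// the feature map: y = {rois.At(0), x.At(1), pooled_h, pooled_w}. A standalone
// sketch with hypothetical values:
#include <array>
#include <cstdint>
std::array<int64_t, 4> RoiAlignOutShape(int64_t num_rois, int64_t channels,
                                        int32_t pooled_h, int32_t pooled_w) {
  // e.g. 128 ROIs over an (N, 256, H, W) feature map with 7x7 pooling
  // -> (128, 256, 7, 7)
  return {num_rois, channels, pooled_h, pooled_w};
}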
diff --git a/oneflow/user/ops/roll_op.cpp b/oneflow/user/ops/roll_op.cpp index cf86f20397c..b07077d814b 100644 --- a/oneflow/user/ops/roll_op.cpp +++ b/oneflow/user/ops/roll_op.cpp @@ -14,48 +14,47 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("roll") - .Input("in") - .Output("out") - .Attr<std::vector<int32_t>>("shifts") - .Attr<std::vector<int32_t>>("dims") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const Shape& in_shape = ctx->InputShape("in", 0); - *ctx->OutputShape("out", 0) = in_shape; - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const std::vector<int32_t>& dims = ctx->Attr<std::vector<int32_t>>("dims"); +/*static*/ Maybe<void> RollOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const std::vector<int32_t>& dims = ctx->Attr<std::vector<int32_t>>("dims"); - CHECK_GT_OR_RETURN(dims.size(), 0); + CHECK_GT_OR_RETURN(dims.size(), 0); - // NOTE(Liang Depeng): (dims.size == 1 && dims[0] == -1) means that user call flow.roll with - // dims == None - if (dims[0] != -1) { - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - if (std::find(dims.begin(), dims.end(), i) == dims.end()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - } + // NOTE(Liang Depeng): (dims.size == 1 && dims[0] == -1) means that user call flow.roll with + // dims == None + if (dims[0] != -1) { + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + if (std::find(dims.begin(), dims.end(), i) == dims.end()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); } + } + } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe<void>::Ok(); - }); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RollOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + *ctx->OutputShape("out", 0) = in_shape; + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> RollOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> RollOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); +} REGISTER_USER_OP_GRAD("roll").SetGenBackwardOpConfFn( [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe<void> {
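// [editor's note -- not part of the diff] RollOp::GetSbp only allows a Split
// on axes that are not rolled, and dims == {-1} is the sentinel for
// flow.roll(..., dims=None), i.e. "roll the flattened tensor", which leaves
// no splittable axis. A standalone sketch of that predicate (hypothetical
// helper):
#include <algorithm>
#include <cstdint>
#include <vector>
bool AxisSplittable(const std::vector<int32_t>& dims, int64_t axis) {
  if (dims.size() == 1 && dims[0] == -1) return false;  // dims == None: flattened roll
  // splittable iff the axis is not one of the rolled dims
  return std::find(dims.begin(), dims.end(), axis) == dims.end();
}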
#include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" #include "oneflow/user/ops/nn_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace user_op { -namespace { -Maybe SamePaddingTensorDescInferFn(user_op::InferContext* ctx) { - const TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); +/*static*/ Maybe SamePaddingOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); + const std::string& data_format = ctx->Attr("data_format"); + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + const int32_t channel_idx = ChannelIdx(data_format, num_axes); + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), channel_idx) + .Split(user_op::OpArg("y", 0), channel_idx) + .Build(); + ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SamePaddingOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); *y_desc->mut_shape() = x_desc.shape(); *y_desc->mut_is_dynamic() = x_desc.is_dynamic(); const std::string& data_format = ctx->Attr("data_format"); @@ -46,88 +58,58 @@ Maybe SamePaddingTensorDescInferFn(user_op::InferContext* ctx) { *y_desc->mut_shape() = Shape(y_dim_vec); return Maybe::Ok(); } -} // namespace - -REGISTER_USER_OP("same_padding") - .Input("x") - .Output("y") - .Attr("padding") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .SetTensorDescInferFn(SamePaddingTensorDescInferFn) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); - const std::string& data_format = ctx->Attr("data_format"); - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - const int32_t channel_idx = ChannelIdx(data_format, num_axes); - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), channel_idx) - .Split(user_op::OpArg("y", 0), channel_idx) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SamePaddingOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SamePaddingOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("same_padding_grad") - .Input("x_like") - .Input("dy") - .Output("dx") - .Attr("padding") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("dilation_rate") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); - *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); - const std::string& data_format = ctx->Attr("data_format"); - 
ctx->NewBuilder() - .Split(user_op::OpArg("x_like", 0), 0) - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - const int32_t channel_idx = ChannelIdx(data_format, num_axes); - ctx->NewBuilder() - .Split(user_op::OpArg("x_like", 0), channel_idx) - .Split(user_op::OpArg("dy", 0), channel_idx) - .Split(user_op::OpArg("dx", 0), channel_idx) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x_like", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x_like", 0)) - .PartialSum(user_op::OpArg("dy", 0)) - .PartialSum(user_op::OpArg("dx", 0)) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x_like", 0)) - .Broadcast(user_op::OpArg("dy", 0)) - .Broadcast(user_op::OpArg("dx", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SamePaddingGradOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("x_like", 0).shape().NumAxes(); + const std::string& data_format = ctx->Attr("data_format"); + ctx->NewBuilder() + .Split(user_op::OpArg("x_like", 0), 0) + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + const int32_t channel_idx = ChannelIdx(data_format, num_axes); + ctx->NewBuilder() + .Split(user_op::OpArg("x_like", 0), channel_idx) + .Split(user_op::OpArg("dy", 0), channel_idx) + .Split(user_op::OpArg("dx", 0), channel_idx) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x_like", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x_like", 0)) + .PartialSum(user_op::OpArg("dy", 0)) + .PartialSum(user_op::OpArg("dx", 0)) + .Build(); + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x_like", 0)) + .Broadcast(user_op::OpArg("dy", 0)) + .Broadcast(user_op::OpArg("dx", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe SamePaddingGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dx", 0) = ctx->InputShape("x_like", 0); + *ctx->OutputIsDynamic("dx", 0) = ctx->InputIsDynamic("x_like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SamePaddingGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SamePaddingGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("x_like", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("same_padding") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -156,5 +138,4 @@ REGISTER_USER_OP_GRAD("same_padding") return Maybe::Ok(); }); -} // namespace user_op } // namespace oneflow diff --git a/oneflow/user/ops/scalar_by_tensor_op.cpp b/oneflow/user/ops/scalar_by_tensor_op.cpp index 0ec8c0adfe4..f5420517a67 100644 --- a/oneflow/user/ops/scalar_by_tensor_op.cpp +++ b/oneflow/user/ops/scalar_by_tensor_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. 
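// [editor's note -- not part of the diff] SamePaddingOp::GetSbp above queries
// LogicalTensorDesc4InputArgNameAndIndex with "x_like", but the forward op's
// only tensor input is "x"; the line is carried over verbatim from the old
// registration (the grad op is the one with an "x_like" input), so this is
// pre-existing behavior rather than a regression from the migration. A sketch
// of the channel-axis rule both ops share (hypothetical helper mirroring what
// ChannelIdx in nn_util presumably computes):
#include <cstdint>
#include <string>
int32_t GuessChannelIdx(const std::string& data_format, int32_t num_axes) {
  // NCHW-style layouts keep channels at axis 1; NHWC-style at the last axis
  return data_format == "channels_first" ? 1 : num_axes - 1;
}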
diff --git a/oneflow/user/ops/scalar_by_tensor_op.cpp b/oneflow/user/ops/scalar_by_tensor_op.cpp index 0ec8c0adfe4..f5420517a67 100644 --- a/oneflow/user/ops/scalar_by_tensor_op.cpp +++ b/oneflow/user/ops/scalar_by_tensor_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -61,20 +62,82 @@ GetSbpFn MakeGetSbpFn(GetSbpFn extra) { } // namespace -REGISTER_USER_OP("scalar_add_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe<void>::Ok(); - })); +/*static*/ Maybe<void> ScalarAddByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ScalarAddByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe<void> ScalarAddByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> ScalarAddByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} + +/*static*/ Maybe<void> ScalarSubByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ScalarSubByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe<void> ScalarSubByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> ScalarSubByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} + +/*static*/ Maybe<void> ScalarMulByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("x", 0)) + .PartialSum(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ScalarMulByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe<void> ScalarMulByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> ScalarMulByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} + +/*static*/ Maybe<void> ScalarDivByTensorOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("x", 0)) + .Broadcast(user_op::OpArg("scalar", 0)) + .PartialSum(user_op::OpArg("y", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ScalarDivByTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return TensorDescInferFn(ctx); +} +/*static*/ Maybe<void> ScalarDivByTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> ScalarDivByTensorOp::InferDataType(user_op::InferContext* ctx) { + return DataTypeInferFn(ctx); +} REGISTER_USER_OP_GRAD("scalar_add_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -99,21 +162,6 @@ REGISTER_USER_OP_GRAD("scalar_add_by_tensor") return Maybe<void>::Ok(); }); -REGISTER_USER_OP("scalar_sub_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe<void>::Ok(); - })); - REGISTER_USER_OP_GRAD("scalar_sub_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe<void> { @@ -148,26 +196,6 @@ REGISTER_USER_OP_GRAD("scalar_sub_by_tensor") return Maybe<void>::Ok(); }); -REGISTER_USER_OP("scalar_mul_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("x", 0)) - .PartialSum(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe<void>::Ok(); - })); - REGISTER_USER_OP_GRAD("scalar_mul_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe<void> { @@ -208,21 +236,6 @@ REGISTER_USER_OP_GRAD("scalar_mul_by_tensor") return Maybe<void>::Ok(); }); -REGISTER_USER_OP("scalar_div_by_tensor") - .Input("x") - .Input("scalar") - .Output("y") - .SetTensorDescInferFn(TensorDescInferFn) - .SetDataTypeInferFn(DataTypeInferFn) - .SetGetSbpFn(MakeGetSbpFn([](user_op::SbpContext* ctx) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("x", 0)) - .Broadcast(user_op::OpArg("scalar", 0)) - .PartialSum(user_op::OpArg("y", 0)) - .Build(); - return Maybe<void>::Ok(); - })); - REGISTER_USER_OP_GRAD("scalar_div_by_tensor") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe<void> {
diff --git a/oneflow/user/ops/scalar_logical_op.cpp b/oneflow/user/ops/scalar_logical_op.cpp index 7f33dfa66c4..7bd176790ee 100644 --- a/oneflow/user/ops/scalar_logical_op.cpp +++ b/oneflow/user/ops/scalar_logical_op.cpp @@ -14,42 +14,39 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -#define REGISTER_SCALAR_LOGICAL_OP(op_name) \ - REGISTER_NO_GRAD_USER_OP(op_name) \ - .Input("in") \ - .Output("out") \ - .Attr<bool>("has_int_operand") \ - .Attr<bool>("has_float_operand") \ - .Attr<int64_t>("int_operand") \ - .Attr<double>("float_operand") \ - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ - return Maybe<void>::Ok(); \ - }) \ - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { \ - const user_op::TensorDesc& in_tensor = \ - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); \ - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { \ - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); \ - } \ - return Maybe<void>::Ok(); \ - }) \ - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { \ - *ctx->OutputDType("out", 0) = DataType::kInt8; \ - return Maybe<void>::Ok(); \ - }); +#define IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(name) \ + /*static*/ Maybe<void> name##Op::GetSbp(user_op::SbpContext* ctx) { \ + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); \ + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { \ + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); \ + } \ + return Maybe<void>::Ok(); \ + } \ + /*static*/ Maybe<void> name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ + return Maybe<void>::Ok(); \ + } \ + /*static*/ Maybe<void> name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe<void> name##Op::InferDataType(user_op::InferContext* ctx) { \ + *ctx->OutputDType("out", 0) = DataType::kInt8; \ + return Maybe<void>::Ok(); \ + } + +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalNotEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreater); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalGreaterEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLess); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalLessEqual); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalAnd); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalOr); +IMPLEMENT_SCALAR_LOGICAL_OP_FUNCS(ScalarLogicalXor); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_not_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_greater"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_greater_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_less"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_less_equal"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_and"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_or"); -REGISTER_SCALAR_LOGICAL_OP("scalar_logical_xor"); } // namespace oneflow
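// [editor's note -- not part of the diff] The macro above expands, for each
// instantiation, into the four static methods of the generated op class;
// roughly (mechanical expansion for ScalarLogicalEqual shown as a comment):
//
//   /*static*/ Maybe<void> ScalarLogicalEqualOp::GetSbp(user_op::SbpContext* ctx) { ... }
//   /*static*/ Maybe<void> ScalarLogicalEqualOp::InferLogicalTensorDesc(...) { ... }
//   /*static*/ Maybe<void> ScalarLogicalEqualOp::InferPhysicalTensorDesc(...) { ... }
//   /*static*/ Maybe<void> ScalarLogicalEqualOp::InferDataType(user_op::InferContext* ctx) {
//     *ctx->OutputDType("out", 0) = DataType::kInt8;  // bool-like output regardless of input
//     return Maybe<void>::Ok();
//   }
//
// Note these ops expose only per-axis Split signatures: comparisons do not
// distribute over partial sums, so no PartialSum variant is registered.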
diff --git a/oneflow/user/ops/scalar_math_op.cpp b/oneflow/user/ops/scalar_math_op.cpp index 82950d50131..3c20827281d 100644 --- a/oneflow/user/ops/scalar_math_op.cpp +++ b/oneflow/user/ops/scalar_math_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -38,55 +39,47 @@ Maybe<void> GetSbp4ScalarMul(user_op::SbpContext* ctx) { } // namespace -#define REGISTER_SCALAR_MATH_OP(op_name, get_sbp_fn) \ - REGISTER_USER_OP(op_name) \ - .Input("in") \ - .Output("out") \ - .Attr<bool>("has_int_operand") \ - .Attr<bool>("has_float_operand") \ - .Attr<int64_t>("int_operand") \ - .Attr<double>("float_operand") \ - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { \ - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ - *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ - return Maybe<void>::Ok(); \ - }) \ - .SetGetSbpFn(get_sbp_fn) \ - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { \ - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); \ - return Maybe<void>::Ok(); \ - }); +#define IMPLEMENT_SCALAR_MATH_OP_FUNCS(op_name, get_sbp_fn) \ + /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) { return get_sbp_fn(ctx); } \ + /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) { \ + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); \ + *ctx->OutputIsDynamic("out", 0) = ctx->InputIsDynamic("in", 0); \ + return Maybe<void>::Ok(); \ + } \ + /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \ + return InferLogicalTensorDesc(ctx); \ + } \ + /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) { \ + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); \ + return Maybe<void>::Ok(); \ + } -REGISTER_SCALAR_MATH_OP("scalar_add", GetSbp4ScalarMath) -REGISTER_SCALAR_MATH_OP("scalar_floordiv", GetSbp4ScalarMath) -REGISTER_SCALAR_MATH_OP("scalar_fmod", GetSbp4ScalarMath) -REGISTER_SCALAR_MATH_OP("scalar_mul", GetSbp4ScalarMul) -REGISTER_SCALAR_MATH_OP("scalar_pow", GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarAdd, GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFloordiv, GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarFmod, GetSbp4ScalarMath) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarMul, GetSbp4ScalarMul) +IMPLEMENT_SCALAR_MATH_OP_FUNCS(ScalarPow, GetSbp4ScalarMath) +#undef IMPLEMENT_SCALAR_MATH_OP_FUNCS -REGISTER_USER_OP("scalar_pow_grad") - .Input("x") - .Input("dy") - .Attr<bool>("has_int_operand") - .Attr<bool>("has_float_operand") - .Attr<int64_t>("int_operand") - .Attr<double>("float_operand") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe<void>::Ok(); - }); +/*static*/ Maybe<void> ScalarPowGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ScalarPowGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("dx", 0) = ctx->InputShape("x", 0); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ScalarPowGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> ScalarPowGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("x", 0), ctx->InputDType("dy", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); +} REGISTER_USER_OP_GRAD("scalar_add") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
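// [editor's note -- not part of the diff] scalar_mul is the one op here wired
// to its own SBP rule (GetSbp4ScalarMul), presumably because multiplying by a
// constant distributes over per-rank partial sums while add/fmod/pow do not.
// Tiny check with hypothetical two-rank partials a, b and scalar k:
#include <cassert>
void ScalarMulKeepsPartialSum() {
  const int a = 3, b = 4, k = 5;
  assert(k * (a + b) == k * a + k * b);      // mul: P -> P is sound
  assert((a + 7) + (b + 7) != (a + b) + 7);  // add: P -> P would double-count the scalar
}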
diff --git a/oneflow/user/ops/selu_op.cpp b/oneflow/user/ops/selu_op.cpp index 8697cbf39a4..8ed852eb395 100644 --- a/oneflow/user/ops/selu_op.cpp +++ b/oneflow/user/ops/selu_op.cpp @@ -14,61 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe<void> SeluOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> SeluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> SeluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> SeluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe<void>::Ok(); +} -REGISTER_USER_OP("selu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe<void>::Ok(); - }); +/*static*/ Maybe<void> SeluGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("x", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) + .Build(); + } + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> SeluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x_shape = ctx->InputShape("x", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK(dy_shape == x_shape); + *dx_shape = dy_shape; + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> SeluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> SeluGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); + return Maybe<void>::Ok(); +} -REGISTER_USER_OP("selu_grad") - .Input("x") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const Shape& x_shape = ctx->InputShape("x", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK(dy_shape == x_shape); - *dx_shape = dy_shape; - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0); - return Maybe<void>::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("selu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) { const auto selu_grad_op_name = ctx->FwOp().op_name() + "_grad";
ctx->InputTensorDesc("prediction", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); + user_op::TensorDesc* loss_desc = ctx->OutputTensorDesc("loss", 0); + *loss_desc->mut_shape() = prediction_desc.shape(); + *loss_desc->mut_is_dynamic() = prediction_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidCrossEntropyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("loss", 0) = ctx->InputDType("prediction", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); + cond_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("sigmoid_cross_entropy_grad") - .Input("prediction") - .Input("loss_diff") - .Input("label") - .Output("prediction_diff") - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); - cond_arg_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); - const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); - const user_op::TensorDesc& loss_diff_desc = ctx->InputTensorDesc("loss_diff", 0); - CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); - CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape()); - user_op::TensorDesc* prediction_diff = ctx->OutputTensorDesc("prediction_diff", 0); - *prediction_diff->mut_shape() = prediction_desc.shape(); - *prediction_diff->mut_is_dynamic() = prediction_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto num_dy_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("loss_diff", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, num_dy_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("loss_diff", 0), i) - .Split(user_op::OpArg("label", 0), i) - .Split(user_op::OpArg("prediction", 0), i) - .Split(user_op::OpArg("prediction_diff", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prediction", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SigmoidCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) { + const auto num_dy_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("loss_diff", 0).shape().NumAxes(); + FOR_RANGE(int64_t, i, 0, num_dy_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("loss_diff", 0), i) + .Split(user_op::OpArg("label", 0), i) + .Split(user_op::OpArg("prediction", 0), i) + .Split(user_op::OpArg("prediction_diff", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0); + const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0); + const user_op::TensorDesc& loss_diff_desc = ctx->InputTensorDesc("loss_diff", 
0); + CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape()); + CHECK_EQ_OR_RETURN(loss_diff_desc.shape(), prediction_desc.shape()); + user_op::TensorDesc* prediction_diff = ctx->OutputTensorDesc("prediction_diff", 0); + *prediction_diff->mut_shape() = prediction_desc.shape(); + *prediction_diff->mut_is_dynamic() = prediction_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prediction", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidCrossEntropyGradOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0); + cond_arg_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("sigmoid_cross_entropy") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/sigmoid_op.cpp b/oneflow/user/ops/sigmoid_op.cpp index 3af60af6440..f45506bc723 100644 --- a/oneflow/user/ops/sigmoid_op.cpp +++ b/oneflow/user/ops/sigmoid_op.cpp @@ -14,63 +14,60 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe SigmoidOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("sigmoid") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe SigmoidGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); + FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), i) + .Split(user_op::OpArg("dy", 0), i) + .Split(user_op::OpArg("dx", 0), i) 
+ .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = dy_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SigmoidGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SigmoidGradOp::InferDataType(user_op::InferContext* ctx) { + CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)); + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("sigmoid_grad") - .Input("y") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = dy_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0); - FOR_RANGE(int64_t, i, 0, y_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), i) - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("y", 0)); - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }); +namespace { REGISTER_USER_OP_GRAD("sigmoid").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp index eb46ab7b406..59c9831bf29 100644 --- a/oneflow/user/ops/silu_op.cpp +++ b/oneflow/user/ops/silu_op.cpp @@ -14,61 +14,58 @@ See the License for the specific language governing permissions and limitations under the License. 
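
The sigmoid file shows the pattern that every op file in this diff follows: the fluent REGISTER_USER_OP(...) builder chain is deleted and the same four hooks (SBP signatures, logical/physical shape inference, dtype inference, plus optional attribute checks and input-arg modifiers) are re-expressed as static member functions of a per-op class pulled in through op_generated.h. As a minimal sketch, assuming the header emitted by the new oneflow-tblgen tool declares something along these lines (the generated file itself is not part of this diff):

    // Hypothetical shape of one generated declaration in op_generated.h.
    // Each static hook corresponds to one Set*Fn(...) call in the old chain.
    class SigmoidOp {
     public:
      static Maybe<void> GetSbp(user_op::SbpContext* ctx);
      static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
      static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);
      static Maybe<void> InferDataType(user_op::InferContext* ctx);
    };

With the declarations generated from the op definitions, a missing or misspelled hook becomes a compile-time error rather than a silently unregistered lambda, which is the main payoff of the migration.
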
diff --git a/oneflow/user/ops/silu_op.cpp b/oneflow/user/ops/silu_op.cpp
index eb46ab7b406..59c9831bf29 100644
--- a/oneflow/user/ops/silu_op.cpp
+++ b/oneflow/user/ops/silu_op.cpp
@@ -14,61 +14,58 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
+/*static*/ Maybe<void> SiluOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SiluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SiluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SiluOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("silu")
-    .Input("in")
-    .Output("out")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SiluGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("x", 0), i)
+        .Split(user_op::OpArg("dy", 0), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SiluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  Shape* dx_shape = ctx->OutputShape("dx", 0);
+  CHECK(dy_shape == x_shape);
+  *dx_shape = dy_shape;
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SiluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SiluGradOp::InferDataType(user_op::InferContext* ctx) {
+  CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("silu_grad")
-    .Input("x")
-    .Input("dy")
-    .Output("dx")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      Shape* dx_shape = ctx->OutputShape("dx", 0);
-      CHECK(dy_shape == x_shape);
-      *dx_shape = dy_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("x", 0), i)
-            .Split(user_op::OpArg("dy", 0), i)
-            .Split(user_op::OpArg("dx", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+namespace {
 
 REGISTER_USER_OP_GRAD("silu").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
   const auto silu_grad_op_name = ctx->FwOp().op_name() + "_grad";
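
For reference, silu_grad consumes x (the forward input) and dy because the derivative of SiLU is most conveniently written in terms of the input. The op file above only fixes shapes, dtypes, and SBP; the well-known closed form the kernel implements is

    \mathrm{silu}(x) = x\,\sigma(x), \qquad
    \frac{d}{dx}\,\mathrm{silu}(x) = \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr),
    \qquad dx = dy \cdot \sigma(x)\bigl(1 + x\,(1 - \sigma(x))\bigr),

a purely elementwise expression, which is why the grad op admits the same split-on-every-axis SBP signatures as the forward op.
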
diff --git a/oneflow/user/ops/slice_op.cpp b/oneflow/user/ops/slice_op.cpp
index 3d5cc31e2b0..0f42d098420 100644
--- a/oneflow/user/ops/slice_op.cpp
+++ b/oneflow/user/ops/slice_op.cpp
@@ -15,19 +15,38 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/user/kernels/slice_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
 namespace {
-
 bool IsFullSlice(int64_t start, int64_t stop, int64_t step, int64_t size) {
   if (step != 1) { return false; }
   if (start != 0) { return false; }
   if (stop != size) { return false; }
   return true;
 }
+}  // namespace
 
-Maybe<void> InferSliceOpTensorDesc(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SliceOp::GetSbp(user_op::SbpContext* ctx) {
+  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
+  const int64_t ndim = x_shape.NumAxes();
+  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
+  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
+  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
+  CHECK_EQ_OR_RETURN(start_vec.size(), ndim);
+  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim);
+  CHECK_EQ_OR_RETURN(step_vec.size(), ndim);
+
+  FOR_RANGE(int, i, 0, ndim) {
+    if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) {
+      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+    }
+  }
+  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SliceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const Shape& x_shape = ctx->InputShape("x", 0);
   const int64_t ndim = x_shape.NumAxes();
   const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
@@ -63,15 +82,17 @@ Maybe<void> InferSliceOpTensorDesc(user_op::InferContext* ctx) {
   *ctx->OutputShape("y", 0) = Shape(dim_vec);
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferSliceOpDataType(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SliceOp::InferDataType(user_op::InferContext* ctx) {
   *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
   return Maybe<void>::Ok();
 }
 
-Maybe<void> GetSliceOpSbpSignature(user_op::SbpContext* ctx) {
-  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
-  const int64_t ndim = x_shape.NumAxes();
+/*static*/ Maybe<void> SliceGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape();
+  const int64_t ndim = like_shape.NumAxes();
   const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
   const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
   const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
@@ -80,15 +101,24 @@ Maybe<void> GetSliceOpSbpSignature(user_op::SbpContext* ctx) {
   CHECK_EQ_OR_RETURN(step_vec.size(), ndim);
 
   FOR_RANGE(int, i, 0, ndim) {
-    if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) {
+    if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), like_shape.At(i))) {
       ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
     }
   }
   ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("dy", 0))
+      .Broadcast(user_op::OpArg("like", 0))
+      .PartialSum(user_op::OpArg("dx", 0))
+      .Build();
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("dy", 0))
+      .PartialSum(user_op::OpArg("like", 0))
+      .Broadcast(user_op::OpArg("dx", 0))
+      .Build();
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferSliceGradOpTensorDesc(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SliceGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const Shape& like_shape = ctx->InputShape("like", 0);
   const Shape& dy_shape = ctx->InputShape("dy", 0);
   const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
@@ -103,15 +133,121 @@ Maybe<void> InferSliceGradOpTensorDesc(user_op::InferContext* ctx) {
   *ctx->OutputShape("dx", 0) = like_shape;
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferSliceGradDataType(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SliceGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SliceGradOp::InferDataType(user_op::InferContext* ctx) {
   *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0);
   return Maybe<void>::Ok();
 }
+/*static*/ Maybe<void> SliceGradOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn,
+                                                   const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0);
+  CHECK_NOTNULL_OR_RETURN(dy_modifier);
+  dy_modifier->set_requires_grad(false);
+  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0);
+  CHECK_NOTNULL_OR_RETURN(like_modifier);
+  like_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
 
-Maybe<void> GetSliceGradOpSbpSignature(user_op::SbpContext* ctx) {
-  const Shape& like_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape();
-  const int64_t ndim = like_shape.NumAxes();
+/*static*/ Maybe<void> LogicalSliceAssignOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0);
+  FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("ref", 0), axis)
+        // TODO(jianhao): Support (S(n), S(n)) when axis n is not sliced
+        .Broadcast(user_op::OpArg("value", 0))
+        .Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("ref", 0))
+      .PartialSum(user_op::OpArg("value", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> LogicalSliceAssignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
+  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
+  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
+  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
+  CHECK_OR_RETURN(!ref_desc.is_dynamic());
+  FOR_RANGE(size_t, i, 0, step_vec.size()) {
+    const int64_t step = step_vec.at(i);
+    const int64_t start = start_vec.at(i);
+    const int64_t stop = stop_vec.at(i);
+    CHECK_GT_OR_RETURN(step, 0) << "logical_slice_assign step must be greater than 0";
+    CHECK_GE_OR_RETURN(start, 0) << "logical_slice_assign start must be greater or equal to 0";
+    CHECK_GT_OR_RETURN(stop, 0) << "logical_slice_assign stop must be greater than 0";
+    CHECK_LT_OR_RETURN(start, stop) << "logical_slice_assign start must be less than stop";
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> LogicalSliceAssignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> LogicalSliceAssignOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
+  const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0);
+  CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type());
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> LogicalSliceAssignOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0);
+  CHECK_OR_RETURN(ref_modifier != nullptr);
+  ref_modifier->set_is_mutable(true);
+  user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0);
+  CHECK_OR_RETURN(value_modifier != nullptr);
+  value_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> LogicalSliceOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  FOR_RANGE(int64_t, axis, 0, input_desc.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("x", 0), axis)
+        // TODO(jianhao): Support S(n) -> S(n) when axis n is not sliced
+        .PartialSum(user_op::OpArg("y", 0))
+        .Build();
+  }
+  ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> LogicalSliceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const int64_t ndim = x_shape.NumAxes();
+  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
+  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
+  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
+  DimVector dim_vec(ndim);
+  FOR_RANGE(size_t, i, 0, dim_vec.size()) {
+    const int64_t step = step_vec.at(i);
+    const int64_t start = start_vec.at(i);
+    const int64_t stop = stop_vec.at(i);
+    CHECK_GT_OR_RETURN(step, 0) << "LogicalSlice step must be greater than 0";
+    CHECK_GE_OR_RETURN(start, 0) << "LogicalSlice start must be greater or equal to 0";
+    CHECK_GT_OR_RETURN(stop, 0) << "LogicalSlice stop must be greater than 0";
+    CHECK_LT_OR_RETURN(start, stop) << "LogicalSlice start must be less than stop";
+    const int64_t diff = stop - start - 1;
+    dim_vec[i] = diff / step + 1;
+  }
+  *ctx->OutputShape("y", 0) = Shape(dim_vec);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> LogicalSliceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> LogicalSliceOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> SliceUpdateOp::GetSbp(user_op::SbpContext* ctx) {
+  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
+  const int64_t ndim = x_shape.NumAxes();
   const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
   const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
   const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
@@ -120,36 +256,14 @@ Maybe<void> GetSliceGradOpSbpSignature(user_op::SbpContext* ctx) {
   CHECK_EQ_OR_RETURN(step_vec.size(), ndim);
 
   FOR_RANGE(int, i, 0, ndim) {
-    if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), like_shape.At(i))) {
+    if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) {
       ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
     }
   }
   ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-  ctx->NewBuilder()
-      .PartialSum(user_op::OpArg("dy", 0))
-      .Broadcast(user_op::OpArg("like", 0))
-      .PartialSum(user_op::OpArg("dx", 0))
-      .Build();
-  ctx->NewBuilder()
-      .Broadcast(user_op::OpArg("dy", 0))
-      .PartialSum(user_op::OpArg("like", 0))
-      .Broadcast(user_op::OpArg("dx", 0))
-      .Build();
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> InferSliceGradInputArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn,
-                                           const user_op::UserOpConfWrapper& conf) {
-  user_op::InputArgModifier* dy_modifier = GetInputArgModifierFn("dy", 0);
-  CHECK_NOTNULL_OR_RETURN(dy_modifier);
-  dy_modifier->set_requires_grad(false);
-  user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0);
-  CHECK_NOTNULL_OR_RETURN(like_modifier);
-  like_modifier->set_requires_grad(false);
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SliceUpdateOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const auto& x_desc = ctx->InputTensorDesc("x", 0);
   const int64_t ndim = x_desc.shape().NumAxes();
   const auto& update_desc = ctx->InputTensorDesc("update", 0);
@@ -185,8 +299,10 @@ Maybe<void> InferSliceUpdateOpTensorDesc(user_op::InferContext* ctx) {
   *y_desc->mut_is_dynamic() = x_desc.is_dynamic();
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferSliceUpdateOpDataType(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SliceUpdateOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SliceUpdateOp::InferDataType(user_op::InferContext* ctx) {
  const auto& x_desc = ctx->InputTensorDesc("x", 0);
   const auto& update_desc = ctx->InputTensorDesc("update", 0);
   CHECK_EQ_OR_RETURN(update_desc.data_type(), x_desc.data_type());
@@ -195,24 +311,7 @@ Maybe<void> InferSliceUpdateOpDataType(user_op::InferContext* ctx) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> GetSliceUpdateOpSbpSignature(user_op::SbpContext* ctx) {
-  const Shape& x_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape();
-  const int64_t ndim = x_shape.NumAxes();
-  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
-  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
-  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
-  CHECK_EQ_OR_RETURN(start_vec.size(), ndim);
-  CHECK_EQ_OR_RETURN(stop_vec.size(), ndim);
-  CHECK_EQ_OR_RETURN(step_vec.size(), ndim);
-
-  FOR_RANGE(int, i, 0, ndim) {
-    if (IsFullSlice(start_vec.at(i), stop_vec.at(i), step_vec.at(i), x_shape.At(i))) {
-      ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-    }
-  }
-  ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build();
-  return Maybe<void>::Ok();
-}
+namespace {
 
 Maybe<void> GenSliceGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) {
   if (op.NeedGenGradTensor4OpInput("x", 0)) {
@@ -231,98 +330,6 @@ Maybe<void> GenSliceGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn Ad
   return Maybe<void>::Ok();
 }
 
-Maybe<void> InferLogicalSliceAssignTensorDesc(user_op::InferContext* ctx) {
-  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
-  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
-  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
-  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
-  CHECK_OR_RETURN(!ref_desc.is_dynamic());
-  FOR_RANGE(size_t, i, 0, step_vec.size()) {
-    const int64_t step = step_vec.at(i);
-    const int64_t start = start_vec.at(i);
-    const int64_t stop = stop_vec.at(i);
-    CHECK_GT_OR_RETURN(step, 0) << "logical_slice_assign step must be greater than 0";
-    CHECK_GE_OR_RETURN(start, 0) << "logical_slice_assign start must be greater or equal to 0";
-    CHECK_GT_OR_RETURN(stop, 0) << "logical_slice_assign stop must be greater than 0";
-    CHECK_LT_OR_RETURN(start, stop) << "logical_slice_assign start must be less than stop";
-  }
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> InferLogicalSliceAssignDataType(user_op::InferContext* ctx) {
-  const user_op::TensorDesc& ref_desc = ctx->InputTensorDesc("ref", 0);
-  const user_op::TensorDesc& value_desc = ctx->InputTensorDesc("value", 0);
-  CHECK_OR_RETURN(ref_desc.data_type() == value_desc.data_type());
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> GetLogicalSliceAssignSbpSignatures(user_op::SbpContext* ctx) {
-  const user_op::TensorDesc& ref_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("ref", 0);
-  FOR_RANGE(int64_t, axis, 0, ref_desc.shape().NumAxes()) {
-    ctx->NewBuilder()
-        .Split(user_op::OpArg("ref", 0), axis)
-        // TODO(jianhao): Support (S(n), S(n)) when axis n is not sliced
-        .Broadcast(user_op::OpArg("value", 0))
-        .Build();
-  }
-  ctx->NewBuilder()
-      .PartialSum(user_op::OpArg("ref", 0))
-      .PartialSum(user_op::OpArg("value", 0))
-      .Build();
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> InferLogicalSliceAssignInputArgModifier(
-    user_op::GetInputArgModifier GetInputArgModifierFn, const user_op::UserOpConfWrapper& conf) {
-  user_op::InputArgModifier* ref_modifier = GetInputArgModifierFn("ref", 0);
-  CHECK_OR_RETURN(ref_modifier != nullptr);
-  ref_modifier->set_is_mutable(true);
-  user_op::InputArgModifier* value_modifier = GetInputArgModifierFn("value", 0);
-  CHECK_OR_RETURN(value_modifier != nullptr);
-  value_modifier->set_requires_grad(false);
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> InferLogicalSliceTensorDesc(user_op::InferContext* ctx) {
-  const Shape& x_shape = ctx->InputShape("x", 0);
-  const int64_t ndim = x_shape.NumAxes();
-  const auto& start_vec = ctx->Attr<std::vector<int64_t>>("start");
-  const auto& stop_vec = ctx->Attr<std::vector<int64_t>>("stop");
-  const auto& step_vec = ctx->Attr<std::vector<int64_t>>("step");
-  DimVector dim_vec(ndim);
-  FOR_RANGE(size_t, i, 0, dim_vec.size()) {
-    const int64_t step = step_vec.at(i);
-    const int64_t start = start_vec.at(i);
-    const int64_t stop = stop_vec.at(i);
-    CHECK_GT_OR_RETURN(step, 0) << "LogicalSlice step must be greater than 0";
-    CHECK_GE_OR_RETURN(start, 0) << "LogicalSlice start must be greater or equal to 0";
-    CHECK_GT_OR_RETURN(stop, 0) << "LogicalSlice stop must be greater than 0";
-    CHECK_LT_OR_RETURN(start, stop) << "LogicalSlice start must be less than stop";
-    const int64_t diff = stop - start - 1;
-    dim_vec[i] = diff / step + 1;
-  }
-  *ctx->OutputShape("y", 0) = Shape(dim_vec);
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> InferLogicalSliceDataType(user_op::InferContext* ctx) {
-  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-  return Maybe<void>::Ok();
-}
-
-Maybe<void> GetLogicalSliceSbpSignatures(user_op::SbpContext* ctx) {
-  const user_op::TensorDesc& input_desc = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-  FOR_RANGE(int64_t, axis, 0, input_desc.shape().NumAxes()) {
-    ctx->NewBuilder()
-        .Split(user_op::OpArg("x", 0), axis)
-        // TODO(jianhao): Support S(n) -> S(n) when axis n is not sliced
-        .PartialSum(user_op::OpArg("y", 0))
-        .Build();
-  }
-  ctx->NewBuilder().PartialSum(user_op::OpArg("x", 0)).PartialSum(user_op::OpArg("y", 0)).Build();
-  return Maybe<void>::Ok();
-}
-
 Maybe<void> GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
   const std::string update_grad_op_name = ctx->FwOp().op_name() + "_update_grad";
   ctx->DefineOp(update_grad_op_name, [&](user_op::BackwardOpBuilder& builder) {
@@ -364,62 +371,7 @@ Maybe<void> GenSliceUpdateGradOp(user_op::BackwardOpConfContext* ctx) {
 
 }  // namespace
 
-REGISTER_USER_OP("slice")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("start")
-    .Attr<std::vector<int64_t>>("stop")
-    .Attr<std::vector<int64_t>>("step")
-    .SetTensorDescInferFn(InferSliceOpTensorDesc)
-    .SetDataTypeInferFn(InferSliceOpDataType)
-    .SetGetSbpFn(GetSliceOpSbpSignature);
-
-REGISTER_USER_OP("slice_grad")
-    .Input("dy")
-    .Input("like")
-    .Output("dx")
-    .Attr<std::vector<int64_t>>("start")
-    .Attr<std::vector<int64_t>>("stop")
-    .Attr<std::vector<int64_t>>("step")
-    .SetTensorDescInferFn(InferSliceGradOpTensorDesc)
-    .SetDataTypeInferFn(InferSliceGradDataType)
-    .SetGetSbpFn(GetSliceGradOpSbpSignature)
-    .SetInputArgModifyFn(InferSliceGradInputArgModifier);
-
-REGISTER_USER_OP("logical_slice_assign")
-    .Input("ref")
-    .Input("value")
-    .Attr<std::vector<int64_t>>("start")
-    .Attr<std::vector<int64_t>>("stop")
-    .Attr<std::vector<int64_t>>("step")
-    .SetTensorDescInferFn(InferLogicalSliceAssignTensorDesc)
-    .SetDataTypeInferFn(InferLogicalSliceAssignDataType)
-    .SetGetSbpFn(GetLogicalSliceAssignSbpSignatures)
-    .SetInputArgModifyFn(InferLogicalSliceAssignInputArgModifier);
-
-REGISTER_USER_OP("logical_slice")
-    .Input("x")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("start")
-    .Attr<std::vector<int64_t>>("stop")
-    .Attr<std::vector<int64_t>>("step")
-    .SetTensorDescInferFn(InferLogicalSliceTensorDesc)
-    .SetDataTypeInferFn(InferLogicalSliceDataType)
-    .SetGetSbpFn(GetLogicalSliceSbpSignatures);
-
 REGISTER_USER_OP_GRAD("slice").SetGenBackwardOpConfFn(GenSliceGradOp);
-
-REGISTER_USER_OP("slice_update")
-    .Input("x")
-    .Input("update")
-    .Output("y")
-    .Attr<std::vector<int64_t>>("start")
-    .Attr<std::vector<int64_t>>("stop")
-    .Attr<std::vector<int64_t>>("step")
-    .SetTensorDescInferFn(InferSliceUpdateOpTensorDesc)
-    .SetDataTypeInferFn(InferSliceUpdateOpDataType)
-    .SetGetSbpFn(GetSliceUpdateOpSbpSignature);
-
 REGISTER_USER_OP_GRAD("slice_update").SetBackwardOpConfGenFn(GenSliceUpdateGradOp);
 
 }  // namespace oneflow
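
Both slice and logical_slice size each sliced dimension as (stop - start - 1) / step + 1: the last index actually taken is the largest start + k*step still below stop. A standalone sketch that checks the arithmetic (names here are illustrative, not OneFlow API):

    #include <cassert>
    #include <cstdint>

    // Mirrors dim_vec[i] = (stop - start - 1) / step + 1 from the
    // shape-inference functions above, for step > 0 and start < stop.
    int64_t SlicedDimSize(int64_t start, int64_t stop, int64_t step) {
      return (stop - start - 1) / step + 1;
    }

    int main() {
      assert(SlicedDimSize(0, 10, 3) == 4);   // takes indices 0, 3, 6, 9
      assert(SlicedDimSize(2, 10, 4) == 2);   // takes indices 2, 6
      assert(SlicedDimSize(0, 10, 1) == 10);  // a full slice keeps the extent
      return 0;
    }

The same formula explains IsFullSlice: with step == 1, start == 0, and stop == size the expression collapses to size, so that axis is untouched and the SBP builders can safely register a Split signature for it.
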
diff --git a/oneflow/user/ops/smooth_l1_loss_op.cpp b/oneflow/user/ops/smooth_l1_loss_op.cpp
index 67c110f526e..025895cb2d7 100644
--- a/oneflow/user/ops/smooth_l1_loss_op.cpp
+++ b/oneflow/user/ops/smooth_l1_loss_op.cpp
@@ -15,12 +15,18 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/user/ops/loss_op_util.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
-
-Maybe<void> InferTensorDescFn(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SmoothL1LossOp::GetSbp(user_op::SbpContext* ctx) {
+  const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape();
+  FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SmoothL1LossOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const auto& input_desc = ctx->InputTensorDesc("input", 0);
   const auto& target_desc = ctx->InputTensorDesc("target", 0);
   CHECK_EQ_OR_RETURN(input_desc.is_dynamic(), target_desc.is_dynamic());
@@ -33,8 +39,10 @@ Maybe<void> InferTensorDescFn(user_op::InferContext* ctx) {
 
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferDataType(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SmoothL1LossOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SmoothL1LossOp::InferDataType(user_op::InferContext* ctx) {
   const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0);
   const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0);
   CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type());
@@ -43,8 +51,27 @@ Maybe<void> InferDataType(user_op::InferContext* ctx) {
 
   return Maybe<void>::Ok();
 }
+/*static*/ Maybe<void> SmoothL1LossOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0);
+  CHECK_OR_RETURN(target_modifier != nullptr);
+  target_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
 
-Maybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SmoothL1LossGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape();
+  FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("input", 0), i)
+        .Split(user_op::OpArg("target", 0), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .Split(user_op::OpArg("dy", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SmoothL1LossGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const auto& input_desc = ctx->InputTensorDesc("input", 0);
   const auto& target_desc = ctx->InputTensorDesc("target", 0);
   const auto& dy_desc = ctx->InputTensorDesc("dy", 0);
@@ -60,8 +87,10 @@ Maybe<void> InferGradTensorDescFn(user_op::InferContext* ctx) {
 
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferGradDataType(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SmoothL1LossGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SmoothL1LossGradOp::InferDataType(user_op::InferContext* ctx) {
   const user_op::TensorDesc& input_desc = ctx->InputTensorDesc("input", 0);
   const user_op::TensorDesc& target_desc = ctx->InputTensorDesc("target", 0);
   CHECK_EQ_OR_RETURN(input_desc.data_type(), target_desc.data_type());
@@ -71,50 +100,6 @@ Maybe<void> InferGradDataType(user_op::InferContext* ctx) {
 
   return Maybe<void>::Ok();
 }
-} // namespace
 
-REGISTER_USER_OP("smooth_l1_loss")
-    .Input("input")
-    .Input("target")
-    .Output("out")
-    .Attr<float>("beta")
-    .SetTensorDescInferFn(InferTensorDescFn)
-    .SetInputArgModifyFn([](const user_op::GetInputArgModifier& GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* target_modifier = GetInputArgModifierFn("target", 0);
-      CHECK_OR_RETURN(target_modifier != nullptr);
-      target_modifier->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn(InferDataType)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape();
-      FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(user_op::OpArg("out", 0), i).Build();
-      }
-      return Maybe<void>::Ok();
-    });
-
-REGISTER_USER_OP("smooth_l1_loss_grad")
-    .Input("input")
-    .Input("target")
-    .Input("dy")
-    .Output("dx")
-    .Attr<float>("beta")
-    .SetTensorDescInferFn(InferGradTensorDescFn)
-    .SetDataTypeInferFn(InferGradDataType)
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const auto& input_shape = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0).shape();
-      FOR_RANGE(int64_t, i, 0, input_shape.NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("input", 0), i)
-            .Split(user_op::OpArg("target", 0), i)
-            .Split(user_op::OpArg("dx", 0), i)
-            .Split(user_op::OpArg("dy", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    });
-
 REGISTER_USER_OP_GRAD("smooth_l1_loss")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                const user_op::AddOpFn& AddOp) -> Maybe<void> {
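
The beta attribute is the usual smooth-L1 threshold. The op file above only fixes shapes, dtypes, and SBP; the loss conventionally computed by the kernels, with d = input − target, is

    \ell_\beta(d) =
    \begin{cases}
      \dfrac{d^2}{2\beta}, & |d| < \beta,\\[4pt]
      |d| - \dfrac{\beta}{2}, & \text{otherwise,}
    \end{cases}

quadratic near zero and linear in the tails. Because it is applied elementwise, input, target, and out share one shape, which is exactly what InferLogicalTensorDesc checks and what makes the per-axis Split signatures valid.
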
diff --git a/oneflow/user/ops/softmax_cross_entropy_op.cpp b/oneflow/user/ops/softmax_cross_entropy_op.cpp
index 85836b93dac..aa42ab0ee40 100644
--- a/oneflow/user/ops/softmax_cross_entropy_op.cpp
+++ b/oneflow/user/ops/softmax_cross_entropy_op.cpp
@@ -14,104 +14,102 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("softmax_cross_entropy")
-    .Input("prediction")
-    .Input("label")
-    .Output("prob")  //'prob' is just for compute prediction's grad, prob's grad will be ignored
-    .Output("out")
-    .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,
-                            const user_op::UserOpConfWrapper&) -> Maybe<void> {
-      user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0);
-      cond_arg_modifier->set_requires_grad(false);
-      return Maybe<void>::Ok();
-    })
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0);
-      const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
-      CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic());
-      CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2);
-      CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape());
-      const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1;
-      DimVector out_dim_vector;
-      FOR_RANGE(int64_t, i, 0, num_out_axes) {
-        out_dim_vector.emplace_back(prediction_desc.shape().At(i));
-      }
-      *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0);
-      *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0);
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic();
-      *out_desc->mut_shape() = Shape(out_dim_vector);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      // ctx->LogicalTensorDesc4InputArgNameAndIndex("out", 0) is not initialized here
-      const auto num_out_axes =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes() - 1;
-      FOR_RANGE(int64_t, i, 0, num_out_axes) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("prediction", 0), i)
-            .Split(user_op::OpArg("label", 0), i)
-            .Split(user_op::OpArg("prob", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0);
-      const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
-      CHECK_EQ_OR_RETURN(label_desc.data_type(), prediction_desc.data_type());
-      *ctx->OutputDType("prob", 0) = ctx->InputDType("prediction", 0);
-      user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
-      *out_desc->mut_data_type() = prediction_desc.data_type();
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SoftmaxCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) {
+  // ctx->LogicalTensorDesc4InputArgNameAndIndex("out", 0) is not initialized here
+  const auto num_out_axes =
+      ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0).shape().NumAxes() - 1;
+  FOR_RANGE(int64_t, i, 0, num_out_axes) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("prediction", 0), i)
+        .Split(user_op::OpArg("label", 0), i)
+        .Split(user_op::OpArg("prob", 0), i)
+        .Split(user_op::OpArg("out", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0);
+  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
+  CHECK_EQ_OR_RETURN(prediction_desc.is_dynamic(), label_desc.is_dynamic());
+  CHECK_GE_OR_RETURN(prediction_desc.shape().NumAxes(), 2);
+  CHECK_EQ_OR_RETURN(label_desc.shape(), prediction_desc.shape());
+  const int64_t num_out_axes = prediction_desc.shape().NumAxes() - 1;
+  DimVector out_dim_vector;
+  FOR_RANGE(int64_t, i, 0, num_out_axes) {
+    out_dim_vector.emplace_back(prediction_desc.shape().At(i));
+  }
+  *ctx->OutputShape("prob", 0) = ctx->InputShape("prediction", 0);
+  *ctx->OutputIsDynamic("prob", 0) = ctx->InputIsDynamic("prediction", 0);
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  *out_desc->mut_is_dynamic() = prediction_desc.is_dynamic();
+  *out_desc->mut_shape() = Shape(out_dim_vector);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& prediction_desc = ctx->InputTensorDesc("prediction", 0);
+  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
+  CHECK_EQ_OR_RETURN(label_desc.data_type(), prediction_desc.data_type());
+  *ctx->OutputDType("prob", 0) = ctx->InputDType("prediction", 0);
+  user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0);
+  *out_desc->mut_data_type() = prediction_desc.data_type();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("label", 0);
+  cond_arg_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("softmax_cross_entropy_grad")
-    .Input("dy")
-    .Input("label")
-    .Input("prob")
-    .Output("prediction_diff")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0);
-      const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
-      const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
-      CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic());
-      CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2);
-      CHECK_EQ_OR_RETURN(dy_desc.shape().NumAxes(), prob_desc.shape().NumAxes() - 1);
-      FOR_RANGE(int64_t, i, 0, dy_desc.shape().NumAxes()) {
-        CHECK_EQ_OR_RETURN(dy_desc.shape().At(i), label_desc.shape().At(i));
-      }
-      CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape());
-      *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0);
-      *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const auto num_dy_axes =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0).shape().NumAxes();
-      FOR_RANGE(int64_t, i, 0, num_dy_axes) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("dy", 0), i)
-            .Split(user_op::OpArg("label", 0), i)
-            .Split(user_op::OpArg("prob", 0), i)
-            .Split(user_op::OpArg("prediction_diff", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0);
-      const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
-      const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
-      CHECK_EQ_OR_RETURN(label_desc.data_type(), prob_desc.data_type());
-      CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type());
-      *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prob", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const auto num_dy_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("dy", 0).shape().NumAxes();
+  FOR_RANGE(int64_t, i, 0, num_dy_axes) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("dy", 0), i)
+        .Split(user_op::OpArg("label", 0), i)
+        .Split(user_op::OpArg("prob", 0), i)
+        .Split(user_op::OpArg("prediction_diff", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0);
+  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
+  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
+  CHECK_EQ_OR_RETURN(prob_desc.is_dynamic(), label_desc.is_dynamic());
+  CHECK_GE_OR_RETURN(prob_desc.shape().NumAxes(), 2);
+  CHECK_EQ_OR_RETURN(dy_desc.shape().NumAxes(), prob_desc.shape().NumAxes() - 1);
+  FOR_RANGE(int64_t, i, 0, dy_desc.shape().NumAxes()) {
+    CHECK_EQ_OR_RETURN(dy_desc.shape().At(i), label_desc.shape().At(i));
+  }
+  CHECK_EQ_OR_RETURN(label_desc.shape(), prob_desc.shape());
+  *ctx->OutputShape("prediction_diff", 0) = ctx->InputShape("prob", 0);
+  *ctx->OutputIsDynamic("prediction_diff", 0) = ctx->InputIsDynamic("prob", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SoftmaxCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& prob_desc = ctx->InputTensorDesc("prob", 0);
+  const user_op::TensorDesc& label_desc = ctx->InputTensorDesc("label", 0);
+  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
+  CHECK_EQ_OR_RETURN(label_desc.data_type(), prob_desc.data_type());
+  CHECK_EQ_OR_RETURN(dy_desc.data_type(), prob_desc.data_type());
+  *ctx->OutputDType("prediction_diff", 0) = ctx->InputDType("prob", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("softmax_cross_entropy")
    .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
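
The shape bookkeeping above encodes out = prediction shape with the trailing class axis reduced away, while prob keeps the full shape for the backward pass. In the standard formulation, with the softmax taken over the last axis,

    \mathrm{prob} = \mathrm{softmax}(\mathrm{prediction}), \qquad
    \mathrm{out} = -\sum_{c} \mathrm{label}_{\ldots,c}\,\log \mathrm{prob}_{\ldots,c},

which also explains the SBP loop: every leading (batch-like) axis may be Split across devices, but the class axis never appears as a Split candidate because the reduction ties it together.
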
diff --git a/oneflow/user/ops/softmax_op.cpp b/oneflow/user/ops/softmax_op.cpp
index e4c8ad8b730..d460508d783 100644
--- a/oneflow/user/ops/softmax_op.cpp
+++ b/oneflow/user/ops/softmax_op.cpp
@@ -14,61 +14,61 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
+/*static*/ Maybe<void> SoftmaxOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), axis)
+        .Split(user_op::OpArg("out", 0), axis)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SoftmaxOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("softmax")
-    .Input("in")
-    .Output("out")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, axis, 0, in_tensor.shape().NumAxes() - 1) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), axis)
-            .Split(user_op::OpArg("out", 0), axis)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SoftmaxGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0);
+  FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("y", 0), axis)
+        .Split(user_op::OpArg("dy", 0), axis)
+        .Split(user_op::OpArg("dx", 0), axis)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& y_shape = ctx->InputShape("y", 0);
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  Shape* dx_shape = ctx->OutputShape("dx", 0);
+  CHECK_OR_RETURN(dy_shape == y_shape);
+  *dx_shape = dy_shape;
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftmaxGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SoftmaxGradOp::InferDataType(user_op::InferContext* ctx) {
+  CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0));
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("softmax_grad")
-    .Input("y")
-    .Input("dy")
-    .Output("dx")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& y_shape = ctx->InputShape("y", 0);
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      Shape* dx_shape = ctx->OutputShape("dx", 0);
-      CHECK_OR_RETURN(dy_shape == y_shape);
-      *dx_shape = dy_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      CHECK_EQ_OR_RETURN(ctx->InputDType("y", 0), ctx->InputDType("dy", 0));
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& y_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("y", 0);
-      FOR_RANGE(int64_t, axis, 0, y_tensor.shape().NumAxes() - 1) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("y", 0), axis)
-            .Split(user_op::OpArg("dy", 0), axis)
-            .Split(user_op::OpArg("dx", 0), axis)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    });
+namespace {
 
 REGISTER_USER_OP_GRAD("softmax").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                                            user_op::AddOpFn AddOp) -> Maybe<void> {
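
As with softmax_cross_entropy, the loops above stop at NumAxes() - 1: softmax normalizes over the last axis,

    \mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}},

so splitting that axis would turn the denominator into a cross-device sum. Every other axis is an independent batch dimension and is registered as a Split signature; the normalized axis stays Broadcast by default.
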
diff --git a/oneflow/user/ops/softsign_op.cpp b/oneflow/user/ops/softsign_op.cpp
index 3249803029e..9cbc34d8cdf 100644
--- a/oneflow/user/ops/softsign_op.cpp
+++ b/oneflow/user/ops/softsign_op.cpp
@@ -14,61 +14,58 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
+/*static*/ Maybe<void> SoftsignOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
+    ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftsignOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftsignOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SoftsignOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("softsign")
-    .Input("in")
-    .Output("out")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("in", 0), i)
-            .Split(user_op::OpArg("out", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SoftsignGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("x", 0), i)
+        .Split(user_op::OpArg("dy", 0), i)
+        .Split(user_op::OpArg("dx", 0), i)
+        .Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftsignGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& x_shape = ctx->InputShape("x", 0);
+  const Shape& dy_shape = ctx->InputShape("dy", 0);
+  Shape* dx_shape = ctx->OutputShape("dx", 0);
+  CHECK(dy_shape == x_shape);
+  *dx_shape = dy_shape;
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SoftsignGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SoftsignGradOp::InferDataType(user_op::InferContext* ctx) {
+  CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("softsign_grad")
-    .Input("x")
-    .Input("dy")
-    .Output("dx")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& x_shape = ctx->InputShape("x", 0);
-      const Shape& dy_shape = ctx->InputShape("dy", 0);
-      Shape* dx_shape = ctx->OutputShape("dx", 0);
-      CHECK(dy_shape == x_shape);
-      *dx_shape = dy_shape;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("x", 0), i)
-            .Split(user_op::OpArg("dy", 0), i)
-            .Split(user_op::OpArg("dx", 0), i)
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      CHECK_EQ_OR_RETURN(ctx->InputDType("dy", 0), ctx->InputDType("x", 0));
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+namespace {
 
 REGISTER_USER_OP_GRAD("softsign").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) {
   const auto softsign_grad_op_name = ctx->FwOp().op_name() + "_grad";
diff --git a/oneflow/user/ops/sort_op.cpp b/oneflow/user/ops/sort_op.cpp
index 469533b88ae..cbcbaa07e48 100644
--- a/oneflow/user/ops/sort_op.cpp
+++ b/oneflow/user/ops/sort_op.cpp
@@ -14,35 +14,35 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("sort")
-    .Input("in")
-    .Output("out")
-    .Attr<std::string>("direction")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      // The current implementation can only do sort in the last dimension and should use Broadcast
-      // (by default) instead of Split for that dimension
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetCheckAttrFn([](const user_op::UserOpDefWrapper& op_def,
-                       const user_op::UserOpConfWrapper& op_conf) -> Maybe<void> {
-      const std::string& direction = op_conf.attr<std::string>("direction");
-      CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING");
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SortOp::GetSbp(user_op::SbpContext* ctx) {
+  // The current implementation can only do sort in the last dimension and should use Broadcast
+  // (by default) instead of Split for that dimension
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SortOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  *ctx->OutputShape("out", 0) = ctx->InputShape("in", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SortOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SortOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SortOp::CheckAttr(const user_op::UserOpDefWrapper&,
+                                         const user_op::UserOpConfWrapper& op_conf) {
+  const std::string& direction = op_conf.attr<std::string>("direction");
+  CHECK_OR_RETURN(direction == "ASCENDING" || direction == "DESCENDING");
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
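
SortOp::CheckAttr replaces the old SetCheckAttrFn hook and rejects any direction other than the two supported orderings when the op conf is validated, before any kernel runs. A standalone sketch of the same validation (illustrative only, not OneFlow API):

    #include <stdexcept>
    #include <string>

    // Mirrors the CHECK_OR_RETURN in SortOp::CheckAttr above: the attribute is
    // a plain string, so typos like "ascending" must be caught up front.
    void ValidateSortDirection(const std::string& direction) {
      if (direction != "ASCENDING" && direction != "DESCENDING") {
        throw std::invalid_argument("sort: direction must be ASCENDING or DESCENDING");
      }
    }

The comment inside GetSbp explains the missing last axis there as well: the implementation sorts within the last dimension, so that dimension stays Broadcast by default rather than being offered as a Split candidate.
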
diff --git a/oneflow/user/ops/sparse_cross_entropy_op.cpp b/oneflow/user/ops/sparse_cross_entropy_op.cpp
index ff4ea687b00..adf9acdebfd 100644
--- a/oneflow/user/ops/sparse_cross_entropy_op.cpp
+++ b/oneflow/user/ops/sparse_cross_entropy_op.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
@@ -71,7 +72,52 @@ Maybe<void> InferDataTypeGrad(user_op::InferContext* ctx) {
   return Maybe<void>::Ok();
 }
 
-Maybe<void> AddMsSignature(user_op::SbpContext* ctx) {
+Maybe<void> GenBackwardOpConf4SparseCrossEntropy(const std::string& op_type_name,
+                                                 const user_op::UserOpWrapper& op,
+                                                 const user_op::AddOpFn& AddOp) {
+  if (op.NeedGenGradTensor4OpInput("prediction", 0)) {
+    user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
+    user_op::UserOpConfWrapper grad_op = builder.Op(op_type_name)
+                                             .Input("prediction", op.input("prediction", 0))
+                                             .Input("label", op.input("label", 0))
+                                             .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
+                                             .Output("prediction_diff")
+                                             .Attr<int64_t>("depth", op.attr<int64_t>("depth"))
+                                             .Build();
+    op.BindGradTensorWithOpInput(grad_op.output("prediction_diff", 0), "prediction", 0);
+    AddOp(grad_op);
+  }
+  return Maybe<void>::Ok();
+}
+
+}  // namespace
+
+/*static*/ Maybe<void> SparseCrossEntropyOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("prediction", 0), 0)
+      .Split(user_op::OpArg("label", 0), 0)
+      .Split(user_op::OpArg("out", 0), 0)
+      .Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SparseCrossEntropyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDescFn(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyOp::InferDataType(user_op::InferContext* ctx) {
+  return oneflow::InferDataType(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0);
+  CHECK_OR_RETURN(label_modifier != nullptr);
+  label_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
+
+/*static*/ Maybe<void> SparseCrossEntropyMsOp::GetSbp(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& prediction =
       ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0);
   ctx->NewBuilder()
@@ -86,17 +132,45 @@ Maybe<void> AddMsSignature(user_op::SbpContext* ctx) {
       .Build();
   return Maybe<void>::Ok();
 }
+/*static*/ Maybe<void> SparseCrossEntropyMsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return InferTensorDescFn(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyMsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyMsOp::InferDataType(user_op::InferContext* ctx) {
+  return oneflow::InferDataType(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyMsOp::ModifyInputArg(
+    const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {
+  user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0);
+  CHECK_OR_RETURN(label_modifier != nullptr);
+  label_modifier->set_requires_grad(false);
+  return Maybe<void>::Ok();
+}
 
-Maybe<void> AddSignature(user_op::SbpContext* ctx) {
+/*static*/ Maybe<void> SparseCrossEntropyGradOp::GetSbp(user_op::SbpContext* ctx) {
   ctx->NewBuilder()
       .Split(user_op::OpArg("prediction", 0), 0)
       .Split(user_op::OpArg("label", 0), 0)
-      .Split(user_op::OpArg("out", 0), 0)
+      .Split(user_op::OpArg("dy", 0), 0)
+      .Split(user_op::OpArg("prediction_diff", 0), 0)
       .Build();
   return Maybe<void>::Ok();
 }
+/*static*/ Maybe<void> SparseCrossEntropyGradOp::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferGradTensorDescFn(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyGradOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SparseCrossEntropyGradOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDataTypeGrad(ctx);
+}
 
-Maybe<void> AddGradMsSignature(user_op::SbpContext* ctx) {
+/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::GetSbp(user_op::SbpContext* ctx) {
   const user_op::TensorDesc& prediction =
       ctx->LogicalTensorDesc4InputArgNameAndIndex("prediction", 0);
   ctx->NewBuilder()
@@ -113,76 +187,18 @@ Maybe<void> AddGradMsSignature(user_op::SbpContext* ctx) {
       .Build();
   return Maybe<void>::Ok();
 }
-
-Maybe<void> AddGradSignature(user_op::SbpContext* ctx) {
-  ctx->NewBuilder()
-      .Split(user_op::OpArg("prediction", 0), 0)
-      .Split(user_op::OpArg("label", 0), 0)
-      .Split(user_op::OpArg("dy", 0), 0)
-      .Split(user_op::OpArg("prediction_diff", 0), 0)
-      .Build();
-  return Maybe<void>::Ok();
+/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::InferLogicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferGradTensorDescFn(ctx);
 }
-
-template<Maybe<void> (*GetSbpSignature)(user_op::SbpContext*)>
-Maybe<void> GetSbpFn(user_op::SbpContext* ctx) {
-  JUST(GetSbpSignature(ctx));
-  return Maybe<void>::Ok();
+/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::InferPhysicalTensorDesc(
+    user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
 }
-
-Maybe<void> GenBackwardOpConf4SparseCrossEntropy(const std::string& op_type_name,
-                                                 const user_op::UserOpWrapper& op,
-                                                 user_op::AddOpFn AddOp) {
-  if (op.NeedGenGradTensor4OpInput("prediction", 0)) {
-    user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad");
-    user_op::UserOpConfWrapper grad_op = builder.Op(op_type_name)
-                                             .Input("prediction", op.input("prediction", 0))
-                                             .Input("label", op.input("label", 0))
-                                             .Input("dy", op.GetGradTensorWithOpOutput("out", 0))
-                                             .Output("prediction_diff")
-                                             .Attr<int64_t>("depth", op.attr<int64_t>("depth"))
-                                             .Build();
-    op.BindGradTensorWithOpInput(grad_op.output("prediction_diff", 0), "prediction", 0);
-    AddOp(grad_op);
-  }
-  return Maybe<void>::Ok();
+/*static*/ Maybe<void> SparseCrossEntropyMsGradOp::InferDataType(user_op::InferContext* ctx) {
+  return InferDataTypeGrad(ctx);
 }
 
-}  // namespace
-
-#define REGISTER_SPAESE_CROSS_ENTROPY_USER_OP(op_name, sbp_sig)                        \
-  REGISTER_USER_OP(op_name)                                                            \
-      .Input("prediction")                                                             \
-      .Input("label")                                                                  \
-      .Output("out")                                                                   \
-      .Attr<int64_t>("depth")                                                          \
-      .SetTensorDescInferFn(InferTensorDescFn)                                         \
-      .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,      \
-                              const user_op::UserOpConfWrapper&) -> Maybe<void> {      \
-        user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \
-        CHECK_OR_RETURN(label_modifier != nullptr);                                    \
-        label_modifier->set_requires_grad(false);                                      \
-        return Maybe<void>::Ok();                                                      \
-      })                                                                               \
-      .SetGetSbpFn(GetSbpFn<sbp_sig>)                                                  \
-      .SetDataTypeInferFn(InferDataType);
-
-#define REGISTER_SPAESE_CROSS_ENTROPY_GRAD_USER_OP(op_name, sbp_sig) \
-  REGISTER_USER_OP(op_name)                                          \
-      .Input("prediction")                                           \
-      .Input("label")                                                \
-      .Input("dy")                                                   \
-      .Output("prediction_diff")                                     \
-      .Attr<int64_t>("depth")                                        \
-      .SetTensorDescInferFn(InferGradTensorDescFn)                   \
-      .SetGetSbpFn(GetSbpFn<sbp_sig>)                                \
-      .SetDataTypeInferFn(InferDataTypeGrad);
-
-REGISTER_SPAESE_CROSS_ENTROPY_USER_OP("sparse_cross_entropy", AddSignature);
-REGISTER_SPAESE_CROSS_ENTROPY_USER_OP("sparse_cross_entropy_ms", AddMsSignature);
-REGISTER_SPAESE_CROSS_ENTROPY_GRAD_USER_OP("sparse_cross_entropy_grad", AddGradSignature);
-REGISTER_SPAESE_CROSS_ENTROPY_GRAD_USER_OP("sparse_cross_entropy_ms_grad", AddGradMsSignature);
-
 REGISTER_USER_OP_GRAD("sparse_cross_entropy")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
                                user_op::AddOpFn AddOp) -> Maybe<void> {
diff --git a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp
index e5df89b6148..5550e1caae8 100644
--- a/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp
+++ b/oneflow/user/ops/sparse_softmax_cross_entropy_op.cpp
@@ -14,6 +14,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
@@ -155,41 +156,47 @@ Maybe<void> GenBackwardOpConf4SparseSoftmaxCrossEntropy(const std::string& op_ty
 
 }  // namespace
 
-#define REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_USER_OP(op_name, sbp_sig)                \
-  REGISTER_USER_OP(op_name)                                                            \
-      .Input("prediction")                                                             \
-      .Input("label")                                                                  \
-      .Output("prob")                                                                  \
-      .Output("out")                                                                   \
-      .Attr<int64_t>("depth")                                                          \
-      .SetTensorDescInferFn(InferTensorDescFn)                                         \
-      .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn,      \
-                              const user_op::UserOpConfWrapper&) -> Maybe<void> {      \
-        user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0); \
-        CHECK_OR_RETURN(label_modifier != nullptr);                                    \
-        label_modifier->set_requires_grad(false);                                      \
-        return Maybe<void>::Ok();                                                      \
-      })                                                                               \
-      .SetGetSbpFn(GetSbpFn<sbp_sig>)                                                  \
-      .SetDataTypeInferFn(InferDataType);
-
-#define REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_USER_OP(op_name, sbp_sig) \
-  REGISTER_USER_OP(op_name)                                                  \
-      .Input("label")                                                        \
-      .Input("dy")                                                           \
-      .Input("prob")                                                         \
-      .Output("prediction_diff")                                             \
-      .Attr<int64_t>("depth")                                                \
-      .SetTensorDescInferFn(InferGradTensorDescFn)                           \
-      .SetGetSbpFn(GetSbpFn<sbp_sig>)                                        \
-      .SetDataTypeInferFn(InferDataTypeGrad);
-
-REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_USER_OP("sparse_softmax_cross_entropy", AddSignature);
-REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_USER_OP("sparse_softmax_cross_entropy_ms", AddMsSignature);
-REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_USER_OP("sparse_softmax_cross_entropy_grad",
-                                                   AddGradSignature);
-REGISTER_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_USER_OP("sparse_softmax_cross_entropy_ms_grad",
-                                                   AddGradMsSignature);
+#define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(op_name, sbp_sig)                       \
+  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) { return sbp_sig(ctx); } \
+  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {      \
+    return InferTensorDescFn(ctx);                                                              \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {     \
+    return InferLogicalTensorDesc(ctx);                                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {               \
+    return oneflow::InferDataType(ctx);                                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##Op::ModifyInputArg(                                           \
+      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {    \
+    user_op::InputArgModifier* label_modifier = GetInputArgModifierFn("label", 0);              \
+    CHECK_OR_RETURN(label_modifier != nullptr);                                                 \
+    label_modifier->set_requires_grad(false);                                                   \
+    return Maybe<void>::Ok();                                                                   \
+  }
+
+IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropy, AddSignature);
+IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS(SparseSoftmaxCrossEntropyMs, AddMsSignature);
+#undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_OP_FUNCS
+
+#define IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(op_name, sbp_sig)                  \
+  /*static*/ Maybe<void> op_name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \
+    return sbp_sig(ctx);                                                                        \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \
+    return InferGradTensorDescFn(ctx);                                                          \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \
+    return InferLogicalTensorDesc(ctx);                                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferDataType(user_op::InferContext* ctx) {           \
+    return InferDataTypeGrad(ctx);                                                              \
+  }
+
+IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropy, AddGradSignature);
+IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS(SparseSoftmaxCrossEntropyMs,
+                                                     AddGradMsSignature);
+#undef IMPLEMENT_SPAESE_SOFTMAX_CROSS_ENTROPY_GRAD_OP_FUNCS
 
 REGISTER_USER_OP_GRAD("sparse_softmax_cross_entropy")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
diff --git a/oneflow/user/ops/split_like_op.cpp b/oneflow/user/ops/split_like_op.cpp
index 80f98d49404..a0e31349603 100644
--- a/oneflow/user/ops/split_like_op.cpp
+++ b/oneflow/user/ops/split_like_op.cpp
@@ -14,12 +14,54 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-namespace {
-
-Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
+/*static*/ Maybe<void> SplitLikeOp::GetSbp(user_op::SbpContext* ctx) {
+  const auto axis = ctx->Attr<int64_t>("axis");
+  const int64_t in_num_axes =
+      ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes();
+  const int64_t like_num_axes =
+      ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes();
+  FOR_RANGE(int64_t, i, 0, like_num_axes) {
+    if (i == axis) { continue; }
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+  }
+  std::vector<user_op::OpArg> like_arg_vec;
+  const size_t like_arg_size = ctx->outputs().size();
+  like_arg_vec.reserve(like_arg_size);
+  FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back("like", i); }
+  FOR_RANGE(int64_t, i, like_num_axes, in_num_axes) {
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), i)
+        .Broadcast(like_arg_vec)
+        .Split(ctx->outputs(), i)
+        .Build();
+    ctx->NewBuilder()
+        .Split(user_op::OpArg("in", 0), i)
+        .PartialSum(like_arg_vec)
+        .Split(ctx->outputs(), i)
+        .Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("in", 0))
+      .PartialSum(like_arg_vec)
+      .PartialSum(ctx->outputs())
+      .Build();
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("in", 0))
+      .Broadcast(like_arg_vec)
+      .PartialSum(ctx->outputs())
+      .Build();
+  ctx->NewBuilder()
+      .Broadcast(user_op::OpArg("in", 0))
+      .PartialSum(like_arg_vec)
+      .Broadcast(ctx->outputs())
+      .Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SplitLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
   const auto axis = ctx->Attr<int64_t>("axis");
   const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0);
   int64_t dynamic_dim_size = 0;
@@ -57,8 +99,10 @@ Maybe<void> InferTensorDesc(user_op::InferContext* ctx) {
   }
   return Maybe<void>::Ok();
 }
-
-Maybe<void> InferDataType(user_op::InferContext* ctx)
{ +/*static*/ Maybe SplitLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SplitLikeOp::InferDataType(user_op::InferContext* ctx) { const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { user_op::TensorDesc* out_i_desc = ctx->OutputTensorDesc("out", i); @@ -66,9 +110,8 @@ Maybe InferDataType(user_op::InferContext* ctx) { } return Maybe::Ok(); } - -Maybe SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper& user_op_conf) { +/*static*/ Maybe SplitLikeOp::ModifyInputArg(const GetInputArgModifier& GetInputArgModifierFn, + const user_op::UserOpConfWrapper& user_op_conf) { FOR_RANGE(int32_t, i, 0, user_op_conf.input_size("like")) { user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", i); CHECK_NOTNULL_OR_RETURN(like_modifier); @@ -77,50 +120,15 @@ Maybe SetLikeArgModifier(user_op::GetInputArgModifier GetInputArgModifierF return Maybe::Ok(); } -Maybe GetSbpSignature(user_op::SbpContext* ctx) { - const auto axis = ctx->Attr("axis"); - const int64_t in_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0).shape().NumAxes(); - const int64_t like_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0).shape().NumAxes(); - FOR_RANGE(int64_t, i, 0, like_num_axes) { - if (i == axis) { continue; } - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - std::vector like_arg_vec; - const size_t like_arg_size = ctx->outputs().size(); - like_arg_vec.reserve(like_arg_size); - FOR_RANGE(int32_t, i, 0, like_arg_size) { like_arg_vec.emplace_back("like", i); } - FOR_RANGE(int64_t, i, like_num_axes, in_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Broadcast(like_arg_vec) - .Split(ctx->outputs(), i) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .PartialSum(like_arg_vec) - .Split(ctx->outputs(), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(like_arg_vec) - .PartialSum(ctx->outputs()) - .Build(); - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .Broadcast(like_arg_vec) - .PartialSum(ctx->outputs()) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("in", 0)) - .PartialSum(like_arg_vec) - .Broadcast(ctx->outputs()) - .Build(); +/*static*/ Maybe SplitLikeOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("like") >= 2); + CHECK_OR_RETURN(op_conf.output_size("out") >= 2); return Maybe::Ok(); } +namespace { + Maybe GenGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { const int64_t axis = op.attr("axis"); const int32_t out_size = op.output_size("out"); @@ -158,16 +166,6 @@ Maybe GenGradOp(const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) } // namespace -REGISTER_USER_OP("split_like") - .Input("in") - .InputWithMinimum("like", 2) - .OutputWithMinimum("out", 2) - .Attr("axis") - .SetTensorDescInferFn(InferTensorDesc) - .SetInputArgModifyFn(SetLikeArgModifier) - .SetGetSbpFn(GetSbpSignature) - .SetDataTypeInferFn(InferDataType); - REGISTER_USER_OP_GRAD("split_like").SetGenBackwardOpConfFn(GenGradOp); } // namespace oneflow diff --git a/oneflow/user/ops/square_sum_op.cpp b/oneflow/user/ops/square_sum_op.cpp index 494688fd9ce..c97d3219046 100644 --- a/oneflow/user/ops/square_sum_op.cpp +++ b/oneflow/user/ops/square_sum_op.cpp @@ -14,61 +14,62 @@ See the 
License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("square_sum")
-    .Input("x")
-    .Output("y")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0);
-      *y->mut_shape() = Shape({1});
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const int64_t num_x_axes =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes();
-      FOR_RANGE(int64_t, i, 0, num_x_axes) {
-        ctx->NewBuilder()
-            .Split(user_op::OpArg("x", 0), i)
-            .PartialSum(user_op::OpArg("y", 0))
-            .Build();
-      }
-      return Maybe<void>::Ok();
-    });
-
-REGISTER_USER_OP("multi_square_sum")
-    .InputWithMinimum("x", 1)
-    .Output("y")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0);
-      *y->mut_shape() = Shape({1});
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_0 = ctx->InputTensorDesc("x", 0);
-      user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0);
-      for (int64_t i = 1; i < ctx->input_size("x"); ++i) {
-        const user_op::TensorDesc& x_i = ctx->InputTensorDesc("x", i);
-        CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type());
-      }
-      *y->mut_data_type() = x_0.data_type();
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes();
-      for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) {
-        min_num_axes = std::min(
-            min_num_axes, ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes());
-      }
-      for (int64_t i = 0; i < min_num_axes; ++i) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build();
-      }
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> SquareSumOp::GetSbp(user_op::SbpContext* ctx) {
+  const int64_t num_x_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes();
+  FOR_RANGE(int64_t, i, 0, num_x_axes) {
+    ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).PartialSum(user_op::OpArg("y", 0)).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0);
+  *y->mut_shape() = Shape({1});
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> SquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> SquareSumOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> MultiSquareSumOp::GetSbp(user_op::SbpContext* ctx) {
+  int64_t min_num_axes = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0).shape().NumAxes();
+  for (int64_t i = 1; i < ctx->user_op_conf().input_size("x"); ++i) {
+    min_num_axes = std::min(min_num_axes,
+                            ctx->LogicalTensorDesc4InputArgNameAndIndex("x", i).shape().NumAxes());
+  }
+  for (int64_t i = 0; i < min_num_axes; ++i) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).PartialSum(user_op::OpArg("y", 0)).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> 
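+// multi_square_sum accumulates every "x" input into one shared scalar, so the
+// logical output shape below is {1} no matter how many inputs there are: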
MultiSquareSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_shape() = Shape({1}); + return Maybe::Ok(); +} +/*static*/ Maybe MultiSquareSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe MultiSquareSumOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_0 = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + for (int64_t i = 1; i < ctx->input_size("x"); ++i) { + const user_op::TensorDesc& x_i = ctx->InputTensorDesc("x", i); + CHECK_EQ_OR_RETURN(x_i.data_type(), x_0.data_type()); + } + *y->mut_data_type() = x_0.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe MultiSquareSumOp::CheckAttr(const user_op::UserOpDefWrapper&, + const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.input_size("x") >= 1); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/squeeze_op.cpp b/oneflow/user/ops/squeeze_op.cpp index 6be95c78159..d6c9cb111a4 100644 --- a/oneflow/user/ops/squeeze_op.cpp +++ b/oneflow/user/ops/squeeze_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -40,21 +41,7 @@ Maybe CheckAndLabelAxesToSqueezeMinusOne(const AxisVector& axes, DimVector } // namespace -Maybe SqueezeTensorDescInferFn(user_op::InferContext* ctx) { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - AxisVector fixed_axes_vec; - JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), - &fixed_axes_vec)); - - DimVector dim_vec = in_shape.dim_vec(); - JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec)); - dim_vec.erase(std::remove(dim_vec.begin(), dim_vec.end(), -1), dim_vec.end()); - *out_shape = Shape(dim_vec); - return Maybe::Ok(); -} - -Maybe SqueezeGetSbpFn(user_op::SbpContext* ctx) { +/*static*/ Maybe SqueezeOp::GetSbp(user_op::SbpContext* ctx) { const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); AxisVector fixed_axes_vec; JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), @@ -74,17 +61,26 @@ Maybe SqueezeGetSbpFn(user_op::SbpContext* ctx) { } return Maybe::Ok(); } +/*static*/ Maybe SqueezeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + AxisVector fixed_axes_vec; + JUST(TransformNegativeAxesToPositive(ctx->Attr>("axes"), in_shape.NumAxes(), + &fixed_axes_vec)); -REGISTER_USER_OP("squeeze") - .Input("in") - .Output("out") - .Attr>("axes") - .SetTensorDescInferFn(SqueezeTensorDescInferFn) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(SqueezeGetSbpFn); + DimVector dim_vec = in_shape.dim_vec(); + JUST(CheckAndLabelAxesToSqueezeMinusOne(fixed_axes_vec, &dim_vec)); + dim_vec.erase(std::remove(dim_vec.begin(), dim_vec.end(), -1), dim_vec.end()); + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe SqueezeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SqueezeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 
0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("squeeze").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { diff --git a/oneflow/user/ops/ssp_variable_proxy_op.cpp b/oneflow/user/ops/ssp_variable_proxy_op.cpp index 1af7c0c4138..9a5a31262a7 100644 --- a/oneflow/user/ops/ssp_variable_proxy_op.cpp +++ b/oneflow/user/ops/ssp_variable_proxy_op.cpp @@ -14,46 +14,41 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_NO_GRAD_USER_OP("ssp_variable_proxy") - .Input("var") - .Output("ref") - .Output("value") - .Attr("buffer_size", 1) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& var_shape = ctx->InputShape("var", 0); - *ctx->OutputShape("ref", 0) = var_shape; - *ctx->OutputShape("value", 0) = var_shape; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const auto& var_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("var", 0); - FOR_RANGE(int64_t, i, 0, var_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("var", 0), i) - .Split(user_op::OpArg("ref", 0), i) - .Split(user_op::OpArg("value", 0), i) - .Build(); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("ref", 0) = ctx->InputDType("var", 0); - *ctx->OutputDType("value", 0) = ctx->InputDType("var", 0); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("ref", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - out_modifier->set_is_mutable(true); - return Maybe::Ok(); - }); - -} // namespace +/*static*/ Maybe SspVariableProxyOp::GetSbp(user_op::SbpContext* ctx) { + const auto& var_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("var", 0); + FOR_RANGE(int64_t, i, 0, var_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("var", 0), i) + .Split(user_op::OpArg("ref", 0), i) + .Split(user_op::OpArg("value", 0), i) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe SspVariableProxyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& var_shape = ctx->InputShape("var", 0); + *ctx->OutputShape("ref", 0) = var_shape; + *ctx->OutputShape("value", 0) = var_shape; + return Maybe::Ok(); +} +/*static*/ Maybe SspVariableProxyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SspVariableProxyOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("ref", 0) = ctx->InputDType("var", 0); + *ctx->OutputDType("value", 0) = ctx->InputDType("var", 0); + return Maybe::Ok(); +} +/*static*/ Maybe SspVariableProxyOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("ref", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + out_modifier->set_is_mutable(true); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/summary_ops.cpp b/oneflow/user/ops/summary_ops.cpp index 0026e7820fd..6235856d1fc 100644 --- a/oneflow/user/ops/summary_ops.cpp +++ b/oneflow/user/ops/summary_ops.cpp @@ -14,11 +14,11 @@ See the 
License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" -#include "oneflow/core/framework/user_op_attr.pb.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace summary { +namespace { Maybe CheckStepShape(const Shape* step) { CHECK_OR_RETURN(step->elem_cnt() == 1); @@ -37,51 +37,85 @@ Maybe CheckInAndStepScalar(user_op::InferContext* ctx) { return Maybe::Ok(); } -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("create_summary_writer") - .Attr("logdir") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("flush_summary_writer") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_scalar") - .Input("in") - .Input("step") - .Input("tag") - .SetTensorDescInferFn(CheckInAndStepScalar) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_histogram") - .Input("in") - .Input("step") - .Input("tag") - .SetTensorDescInferFn(CheckStepShapeInCtx) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_pb") - .Input("in") - .Input("step") - .SetTensorDescInferFn(CheckStepShapeInCtx) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("summary_write_image") - .Input("in") - .Input("step") - .Input("tag") - .SetTensorDescInferFn(CheckStepShapeInCtx) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { return Maybe::Ok(); }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); -} // namespace summary +} // namespace + +/*static*/ Maybe CreateSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe CreateSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe CreateSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe CreateSummaryWriterOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe FlushSummaryWriterOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe FlushSummaryWriterOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe FlushSummaryWriterOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return Maybe::Ok(); +} +/*static*/ Maybe FlushSummaryWriterOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWriteScalarOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe 
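+// the summary writer ops have no tensor outputs; their "inference" hooks only
+// validate that "in" and "step" are scalars before an event is recorded: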
SummaryWriteScalarOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckInAndStepScalar(ctx); +} +/*static*/ Maybe SummaryWriteScalarOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWriteScalarOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWriteHistogramOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWriteHistogramOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckStepShapeInCtx(ctx); +} +/*static*/ Maybe SummaryWriteHistogramOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWriteHistogramOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWritePbOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWritePbOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckStepShapeInCtx(ctx); +} +/*static*/ Maybe SummaryWritePbOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWritePbOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} + +/*static*/ Maybe SummaryWriteImageOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe SummaryWriteImageOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return CheckStepShapeInCtx(ctx); +} +/*static*/ Maybe SummaryWriteImageOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe SummaryWriteImageOp::InferDataType(user_op::InferContext* ctx) { + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/tanh_op.cpp b/oneflow/user/ops/tanh_op.cpp index 639b9529e72..caf89a63ac9 100644 --- a/oneflow/user/ops/tanh_op.cpp +++ b/oneflow/user/ops/tanh_op.cpp @@ -14,23 +14,35 @@ See the License for the specific language governing permissions and limitations under the License. 
*/
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("tanh")
-    .Input("x")
-    .Output("y")
-    .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged)
-    .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis)
-    .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType);
-
-REGISTER_USER_OP((std::string("") + "tanh" + "_grad"))
-    .Input("x")
-    .Input("dy")
-    .Output("dx")
-    .SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged)
-    .SetGetSbpFn(user_op::GetSbpFnUtil::SplitForEachAxis)
-    .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType);
+/*static*/ Maybe<void> TanhOp::GetSbp(user_op::SbpContext* ctx) {
+  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);
+}
+/*static*/ Maybe<void> TanhOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return user_op::TensorDescInferFnUtil::Unchanged(ctx);
+}
+/*static*/ Maybe<void> TanhOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TanhOp::InferDataType(user_op::InferContext* ctx) {
+  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);
+}
+
+/*static*/ Maybe<void> TanhGradOp::GetSbp(user_op::SbpContext* ctx) {
+  return user_op::GetSbpFnUtil::SplitForEachAxis(ctx);
+}
+/*static*/ Maybe<void> TanhGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  return user_op::TensorDescInferFnUtil::Unchanged(ctx);
+}
+/*static*/ Maybe<void> TanhGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TanhGradOp::InferDataType(user_op::InferContext* ctx) {
+  return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx);
+}
 
 REGISTER_USER_OP_GRAD("tanh").SetGenBackwardOpConfFn(
     [](const user_op::UserOpWrapper& op, const user_op::AddOpFn& AddOp) -> Maybe<void> {
diff --git a/oneflow/user/ops/tensor_buffer_ops.cpp b/oneflow/user/ops/tensor_buffer_ops.cpp
index f1e964c50db..80b1c5c99ff 100644
--- a/oneflow/user/ops/tensor_buffer_ops.cpp
+++ b/oneflow/user/ops/tensor_buffer_ops.cpp
@@ -14,199 +14,197 @@ See the License for the specific language governing permissions and
 limitations under the License.
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { +/*static*/ Maybe TensorBufferToTensorOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + out->set_is_dynamic(in.is_dynamic()); + const auto& instance_shape = ctx->Attr("instance_shape"); + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cend()); + dim_vec.insert(dim_vec.end(), instance_shape.dim_vec().cbegin(), instance_shape.dim_vec().cend()); + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorBufferToTensorOp::InferDataType(user_op::InferContext* ctx) { + const auto data_type = ctx->Attr("dtype"); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(IsPODDataType(data_type)); + *out->mut_data_type() = data_type; + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_buffer_to_tensor") - .Input("in") - .Output("out") - .Attr("instance_shape") - .Attr("dtype") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - out->set_is_dynamic(in.is_dynamic()); - const auto& instance_shape = ctx->Attr("instance_shape"); - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), in.shape().dim_vec().cbegin(), in.shape().dim_vec().cend()); - dim_vec.insert(dim_vec.end(), instance_shape.dim_vec().cbegin(), - instance_shape.dim_vec().cend()); - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto data_type = ctx->Attr("dtype"); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(IsPODDataType(data_type)); - *out->mut_data_type() = data_type; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TensorToTensorBufferOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + const auto& instance_dims = ctx->Attr("instance_dims"); + CHECK_LE_OR_RETURN(instance_dims, in.shape().NumAxes()); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - instance_dims) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorToTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in.shape(); + const auto& instance_dims = 
ctx->Attr("instance_dims"); + CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes()); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + out->set_is_dynamic(in.is_dynamic()); + DimVector out_dim_vec; + out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(), + in_shape.dim_vec().cend() - instance_dims); + *out->mut_shape() = Shape(out_dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe TensorToTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorToTensorBufferOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_OR_RETURN(IsPODDataType(in.data_type())); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_to_tensor_buffer") - .Input("in") - .Output("out") - .Attr("instance_dims") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - const Shape& in_shape = in.shape(); - const auto& instance_dims = ctx->Attr("instance_dims"); - CHECK_LT_OR_RETURN(instance_dims, in_shape.NumAxes()); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - out->set_is_dynamic(in.is_dynamic()); - DimVector out_dim_vec; - out_dim_vec.insert(out_dim_vec.end(), in_shape.dim_vec().cbegin(), - in_shape.dim_vec().cend() - instance_dims); - *out->mut_shape() = Shape(out_dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_OR_RETURN(IsPODDataType(in.data_type())); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - const auto& instance_dims = ctx->Attr("instance_dims"); - CHECK_LE_OR_RETURN(instance_dims, in.shape().NumAxes()); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - instance_dims) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe GenTensorBufferOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe GenTensorBufferOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + const Shape& shape = ctx->Attr("shape"); + const int64_t num_tensor_buffers = shape.elem_cnt(); + const std::vector& shape_list = ctx->Attr>("shape_list"); + const std::vector& value_list = ctx->Attr>("value_list"); + CHECK_EQ_OR_RETURN(num_tensor_buffers, shape_list.size()); + CHECK_EQ_OR_RETURN(num_tensor_buffers, value_list.size()); + *out->mut_shape() = shape; + out->set_is_dynamic(ctx->Attr("dynamic_out")); + return Maybe::Ok(); +} +/*static*/ Maybe GenTensorBufferOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe GenTensorBufferOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = DataType::kTensorBuffer; + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("gen_tensor_buffer") - .Output("out") - .Attr("shape") - 
.Attr>("shape_list") - .Attr>("value_list") - .Attr("data_type") - .Attr("dynamic_out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - const Shape& shape = ctx->Attr("shape"); - const int64_t num_tensor_buffers = shape.elem_cnt(); - const std::vector& shape_list = ctx->Attr>("shape_list"); - const std::vector& value_list = ctx->Attr>("value_list"); - CHECK_EQ_OR_RETURN(num_tensor_buffers, shape_list.size()); - CHECK_EQ_OR_RETURN(num_tensor_buffers, value_list.size()); - *out->mut_shape() = shape; - out->set_is_dynamic(ctx->Attr("dynamic_out")); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = DataType::kTensorBuffer; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TensorBufferToListOfTensorsOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); + CHECK_OR_RETURN(!in.is_dynamic()); + const Shape& out_shape = ctx->Attr("out_shape"); + const bool dynamic_out = ctx->Attr("dynamic_out"); + int64_t num_tensor_buffers = in.shape().elem_cnt(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_shape() = out_shape; + out_i->set_is_dynamic(dynamic_out); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); + const DataType out_dtype = ctx->Attr("out_dtype"); + CHECK_OR_RETURN(IsPODDataType(out_dtype)); + int64_t num_tensor_buffers = ctx->outputs().size(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_data_type() = out_dtype; + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + if (conf.attr("dynamic_out")) { + FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { + user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); + CHECK_OR_RETURN(out_i_modifier != nullptr); + out_i_modifier->set_header_infered_before_compute(false); + } + } + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_buffer_to_list_of_tensors") - .Input("in") - .OutputWithMinimum("out", 1) - .Attr("out_shape") - .Attr("out_dtype") - .Attr("dynamic_out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); - CHECK_OR_RETURN(!in.is_dynamic()); - const Shape& out_shape = ctx->Attr("out_shape"); - const bool dynamic_out = ctx->Attr("dynamic_out"); - int64_t num_tensor_buffers = in.shape().elem_cnt(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = 
ctx->OutputTensorDesc("out", i); - *out_i->mut_shape() = out_shape; - out_i->set_is_dynamic(dynamic_out); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); - const DataType out_dtype = ctx->Attr("out_dtype"); - CHECK_OR_RETURN(IsPODDataType(out_dtype)); - int64_t num_tensor_buffers = ctx->outputs().size(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); - *out_i->mut_data_type() = out_dtype; - } - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - if (conf.attr("dynamic_out")) { - FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { - user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); - CHECK_OR_RETURN(out_i_modifier != nullptr); - out_i_modifier->set_header_infered_before_compute(false); - } - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TensorBufferToListOfTensorsOp::CheckAttr( + const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.output_size("out") >= 1); + return Maybe::Ok(); +} -REGISTER_NO_GRAD_CPU_ONLY_USER_OP("tensor_buffer_to_list_of_tensors_v2") - .Input("in") - .OutputWithMinimum("out", 1) - .Attr>("out_shapes") - .Attr>("out_dtypes") - .Attr("dynamic_out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); - CHECK_OR_RETURN(!in.is_dynamic()); - const std::vector& out_shapes = ctx->Attr>("out_shapes"); - const bool dynamic_out = ctx->Attr("dynamic_out"); - int64_t num_tensor_buffers = in.shape().elem_cnt(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); - *out_i->mut_shape() = out_shapes[i]; - out_i->set_is_dynamic(dynamic_out); - } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); - const std::vector& out_dtypes = ctx->Attr>("out_dtypes"); - int64_t num_tensor_buffers = ctx->outputs().size(); - for (int64_t i = 0; i < num_tensor_buffers; ++i) { - CHECK_OR_RETURN(IsPODDataType(out_dtypes[i])); - user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); - *out_i->mut_data_type() = out_dtypes[i]; - } - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - if (conf.attr("dynamic_out")) { - FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { - user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); - CHECK_OR_RETURN(out_i_modifier != nullptr); - out_i_modifier->set_header_infered_before_compute(false); - } - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); - -} // namespace +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const 
user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_GT_OR_RETURN(in.shape().elem_cnt(), 0); + CHECK_OR_RETURN(!in.is_dynamic()); + const std::vector& out_shapes = ctx->Attr>("out_shapes"); + const bool dynamic_out = ctx->Attr("dynamic_out"); + int64_t num_tensor_buffers = in.shape().elem_cnt(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_shape() = out_shapes[i]; + out_i->set_is_dynamic(dynamic_out); + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + CHECK_EQ_OR_RETURN(in.data_type(), DataType::kTensorBuffer); + const std::vector& out_dtypes = ctx->Attr>("out_dtypes"); + int64_t num_tensor_buffers = ctx->outputs().size(); + for (int64_t i = 0; i < num_tensor_buffers; ++i) { + CHECK_OR_RETURN(IsPODDataType(out_dtypes[i])); + user_op::TensorDesc* out_i = ctx->OutputTensorDesc("out", i); + *out_i->mut_data_type() = out_dtypes[i]; + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper& conf) { + if (conf.attr("dynamic_out")) { + FOR_RANGE(int64_t, i, 0, conf.output_size("out")) { + user_op::OutputArgModifier* out_i_modifier = GetOutputArgModifierFn("out", i); + CHECK_OR_RETURN(out_i_modifier != nullptr); + out_i_modifier->set_header_infered_before_compute(false); + } + } + return Maybe::Ok(); +} +/*static*/ Maybe TensorBufferToListOfTensorsV2Op::CheckAttr( + const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper& op_conf) { + CHECK_OR_RETURN(op_conf.output_size("out") >= 1); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/test_ops.cpp b/oneflow/user/ops/test_ops.cpp index efbff608f89..2b2249dc8f3 100644 --- a/oneflow/user/ops/test_ops.cpp +++ b/oneflow/user/ops/test_ops.cpp @@ -15,210 +15,206 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/core/common/balanced_splitter.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("ccrelu") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CcreluOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe CcreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("x", 0); + Shape* out_shape = ctx->OutputShape("y", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe CcreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe CcreluOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("ccrelu_grad") - .Input("y") - .Input("dy") - .Output("dx") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& y_shape = ctx->InputShape("y", 0); - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(dy_shape == y_shape); - *dx_shape = y_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder() - .Split(user_op::OpArg("y", 0), 0) - .Split(user_op::OpArg("dy", 0), 0) - .Split(user_op::OpArg("dx", 0), 0) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CcreluGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder() + .Split(user_op::OpArg("y", 0), 0) + .Split(user_op::OpArg("dy", 0), 0) + .Split(user_op::OpArg("dx", 0), 0) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe CcreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& y_shape = ctx->InputShape("y", 0); + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(dy_shape == y_shape); + *dx_shape = y_shape; + return Maybe::Ok(); +} +/*static*/ Maybe CcreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe CcreluGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("y", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("ccrelu").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { - if (op.NeedGenGradTensor4OpInput("in", 0)) { + if (op.NeedGenGradTensor4OpInput("x", 0)) { user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); user_op::UserOpConfWrapper ccrelu_grad_op = builder.Op("ccrelu_grad") - .Input("y", op.output("out", 0)) - .Input("dy", op.GetGradTensorWithOpOutput("out", 0)) + .Input("y", op.output("y", 0)) + .Input("dy", op.GetGradTensorWithOpOutput("y", 
0)) .Output("dx") .Build(); - op.BindGradTensorWithOpInput(ccrelu_grad_op.output("dx", 0), "in", 0); + op.BindGradTensorWithOpInput(ccrelu_grad_op.output("dx", 0), "x", 0); AddOp(ccrelu_grad_op); } return Maybe::Ok(); }); -REGISTER_USER_OP("TestReshape") - .Input("in") - .Output("out") - .Attr("shape") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& conf_shape = ctx->Attr("shape"); - CHECK_EQ_OR_RETURN(in_shape.NumAxes(), conf_shape.NumAxes()); - *out_shape = conf_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestReshapeOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& conf_shape = ctx->Attr("shape"); + CHECK_EQ_OR_RETURN(in_shape.NumAxes(), conf_shape.NumAxes()); + *out_shape = conf_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestReshapeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestReshapeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestSource") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = Shape({5}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }); +/*static*/ Maybe TestSourceOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = Shape({5}); + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestSourceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} -REGISTER_USER_OP("TestMultiOutputOrder") - .Input("in") - .Output("out1") - .Output("out2") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out1_shape = ctx->OutputShape("out1", 0); - Shape* out2_shape = ctx->OutputShape("out2", 0); - *out1_shape = in_shape; - *out2_shape = in_shape; - int32_t last_axis = in_shape.NumAxes() - 1; - out2_shape->Set(last_axis, in_shape.At(last_axis) * 2); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out1", 0) = ctx->InputDType("in", 0); - *ctx->OutputDType("out2", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 
0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe TestMultiOutputOrderOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiOutputOrderOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out1_shape = ctx->OutputShape("out1", 0); + Shape* out2_shape = ctx->OutputShape("out2", 0); + *out1_shape = in_shape; + *out2_shape = in_shape; + int32_t last_axis = in_shape.NumAxes() - 1; + out2_shape->Set(last_axis, in_shape.At(last_axis) * 2); + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiOutputOrderOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestMultiOutputOrderOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out1", 0) = ctx->InputDType("in", 0); + *ctx->OutputDType("out2", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestSourceMultiGpuFixedOutNum") - .Output("out") - .Attr("out_num") - .SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - int64_t out_num = ctx->Attr("out_num"); - *out_shape = Shape({out_num}); - return Maybe::Ok(); - }) - .SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - Shape* out_shape = ctx->OutputShape("out", 0); - int64_t out_num = ctx->Attr("out_num"); - const ParallelContext& parallel_ctx = ctx->parallel_ctx(); - BalancedSplitter bs(out_num, parallel_ctx.parallel_num()); - *out_shape = Shape({bs.At(parallel_ctx.parallel_id()).size()}); +/*static*/ Maybe TestSourceMultiGpuFixedOutNumOp::GetSbp(user_op::SbpContext* ctx) { + int64_t parallel_num = ctx->parallel_num(); + DeviceType device_type = ctx->device_type(); + if (device_type == DeviceType::kCPU && parallel_num > 1) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceMultiGpuFixedOutNumOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + int64_t out_num = ctx->Attr("out_num"); + *out_shape = Shape({out_num}); + return Maybe::Ok(); +} +/*static*/ Maybe TestSourceMultiGpuFixedOutNumOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + Shape* out_shape = ctx->OutputShape("out", 0); + int64_t out_num = ctx->Attr("out_num"); + const ParallelContext& parallel_ctx = ctx->parallel_ctx(); + BalancedSplitter bs(out_num, parallel_ctx.parallel_num()); + *out_shape = Shape({bs.At(parallel_ctx.parallel_id()).size()}); - const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); - CHECK_OR_RETURN(out_sbp.has_split_parallel() && out_sbp.split_parallel().axis() == 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - int64_t parallel_num = ctx->parallel_num(); - DeviceType device_type = ctx->device_type(); - if (device_type == DeviceType::kCPU && parallel_num > 1) { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - } - return Maybe::Ok(); - }); + const cfg::SbpParallel& out_sbp = ctx->SbpParallel4ArgNameAndIndex("out", 0); + CHECK_OR_RETURN(out_sbp.has_split_parallel() && out_sbp.split_parallel().axis() == 0); + return Maybe::Ok(); +} +/*static*/ Maybe 
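+// the physical shape above is this rank's BalancedSplitter slice of "out_num",
+// checked against the expected S(0) signature; the element type is fixed below: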
TestSourceMultiGpuFixedOutNumOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} -REGISTER_USER_OP("TestMultiInput") - .Input("x1") - .Input("x2") - .Output("y") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x1_shape = ctx->InputShape("x1", 0); - const Shape& x2_shape = ctx->InputShape("x2", 0); - Shape* y_shape = ctx->OutputShape("y", 0); - CHECK_OR_RETURN(x1_shape == x2_shape); - *y_shape = x1_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x1", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); - FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TestMultiInputOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); + FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x1_shape = ctx->InputShape("x1", 0); + const Shape& x2_shape = ctx->InputShape("x2", 0); + Shape* y_shape = ctx->OutputShape("y", 0); + CHECK_OR_RETURN(x1_shape == x2_shape); + *y_shape = x1_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestMultiInputOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x1", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestMultiInputGrad") - .Input("x1") - .Input("x2") - .Input("y_diff") - .Output("x1_diff") - .Output("x2_diff") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& x1_shape = ctx->InputShape("x1", 0); - const Shape& x2_shape = ctx->InputShape("x2", 0); - Shape* x1_diff_shape = ctx->OutputShape("x1_diff", 0); - Shape* x2_diff_shape = ctx->OutputShape("x2_diff", 0); - *x1_diff_shape = x1_shape; - *x2_diff_shape = x2_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("x1_diff", 0) = ctx->InputDType("x1", 0); - *ctx->OutputDType("x2_diff", 0) = ctx->InputDType("x2", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); - FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TestMultiInputGradOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& x1_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x1", 0); + FOR_RANGE(int64_t, i, 0, x1_tensor.shape().NumAxes()) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& x1_shape = ctx->InputShape("x1", 0); + const Shape& x2_shape = ctx->InputShape("x2", 0); + Shape* x1_diff_shape = 
ctx->OutputShape("x1_diff", 0); + Shape* x2_diff_shape = ctx->OutputShape("x2_diff", 0); + *x1_diff_shape = x1_shape; + *x2_diff_shape = x2_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestMultiInputGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestMultiInputGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("x1_diff", 0) = ctx->InputDType("x1", 0); + *ctx->OutputDType("x2_diff", 0) = ctx->InputDType("x2", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("TestMultiInput") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -240,110 +236,121 @@ REGISTER_USER_OP_GRAD("TestMultiInput") return Maybe::Ok(); }); -REGISTER_USER_OP("TestDynamicSource") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = Shape({5}); - out_tensor->set_is_dynamic(true); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }) - .SetOutputArgModifyFn([](user_op::GetOutputArgModifier GetOutputArgModifierFn, - const user_op::UserOpConfWrapper& conf) -> Maybe { - user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); - CHECK_OR_RETURN(out_modifier != nullptr); - out_modifier->set_header_infered_before_compute(false); - return Maybe::Ok(); - }); +/*static*/ Maybe TestDynamicSourceOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TestDynamicSourceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({5}); + out_tensor->set_is_dynamic(true); + return Maybe::Ok(); +} +/*static*/ Maybe TestDynamicSourceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestDynamicSourceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} +/*static*/ Maybe TestDynamicSourceOp::ModifyOutputArg( + const GetOutputArgModifier& GetOutputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::OutputArgModifier* out_modifier = GetOutputArgModifierFn("out", 0); + CHECK_OR_RETURN(out_modifier != nullptr); + out_modifier->set_header_infered_before_compute(false); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestRandomSource") - .Output("out") - .Attr("seed") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); - *out_tensor->mut_shape() = Shape({5}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = DataType::kFloat; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestRandomSourceOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestRandomSourceOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + user_op::TensorDesc* out_tensor = ctx->OutputTensorDesc("out", 0); + *out_tensor->mut_shape() = Shape({5}); + return Maybe::Ok(); 
+} +/*static*/ Maybe TestRandomSourceOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestRandomSourceOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = DataType::kFloat; + return Maybe::Ok(); +} -REGISTER_USER_OP("TestDataTypeAttr") - .Input("in") - .Output("out") - .Attr("output_type") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& in_shape = ctx->InputShape("in", 0); - Shape* out_shape = ctx->OutputShape("out", 0); - *out_shape = in_shape; - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->Attr("output_type"); - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestDataTypeAttrOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestDataTypeAttrOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& in_shape = ctx->InputShape("in", 0); + Shape* out_shape = ctx->OutputShape("out", 0); + *out_shape = in_shape; + return Maybe::Ok(); +} +/*static*/ Maybe TestDataTypeAttrOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestDataTypeAttrOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->Attr("output_type"); + return Maybe::Ok(); +} -REGISTER_USER_OP("TestListDataTypeAndListShapeAndListStringAttr") - .Input("in") - .Output("out", 3) - .Attr>("out_shapes") - .Attr>("out_types") - .Attr>("string_list") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& out_shapes = ctx->Attr>("out_shapes"); - const auto& string_list = ctx->Attr>("string_list"); - FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { - *ctx->OutputShape("out", i) = out_shapes.at(i); - } - CHECK_GT_OR_RETURN(string_list.size(), 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& out_types = ctx->Attr>("out_types"); - FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { - *ctx->OutputDType("out", i) = out_types.at(i); - } - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::GetSbp( + user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const auto& out_shapes = ctx->Attr>("out_shapes"); + const auto& string_list = ctx->Attr>("string_list"); + FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { + *ctx->OutputShape("out", i) = out_shapes.at(i); + } + CHECK_GT_OR_RETURN(string_list.size(), 0); + return Maybe::Ok(); +} +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestListDataTypeAndListShapeAndListStringAttrOp::InferDataType( + user_op::InferContext* ctx) { + const auto& out_types = ctx->Attr>("out_types"); + FOR_RANGE(int32_t, i, 0, ctx->outputs().size()) { *ctx->OutputDType("out", i) = out_types.at(i); } + return Maybe::Ok(); +} -REGISTER_USER_OP("test_user_op_attr_auto_type") - .Input("in") - .Output("out") - .Attr("int1") - .Attr("int2") - 
.SetTensorDescInferFn(user_op::TensorDescInferFnUtil::Unchanged) - .SetDataTypeInferFn(user_op::TensorDescInferFnUtil::UnchangedDataType) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); +/*static*/ Maybe TestUserOpAttrAutoTypeOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe TestUserOpAttrAutoTypeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::Unchanged(ctx); +} +/*static*/ Maybe TestUserOpAttrAutoTypeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TestUserOpAttrAutoTypeOp::InferDataType(user_op::InferContext* ctx) { + return user_op::TensorDescInferFnUtil::UnchangedDataType(ctx); +} -REGISTER_CPU_ONLY_USER_OP("cpu_only_relu_test") - .Input("in") - .Output("out") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const auto& in_desc = ctx->InputTensorDesc("in", 0); - auto* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe CpuOnlyReluTestOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(ctx->inputs(), 0).Split(ctx->outputs(), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe CpuOnlyReluTestOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const auto& in_desc = ctx->InputTensorDesc("in", 0); + auto* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe CpuOnlyReluTestOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe CpuOnlyReluTestOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp index f264ecb2378..b4880e201e7 100644 --- a/oneflow/user/ops/tf_prelu_op.cpp +++ b/oneflow/user/ops/tf_prelu_op.cpp @@ -14,111 +14,106 @@ See the License for the specific language governing permissions and limitations under the License. 
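[Editor's note — illustrative sketch, not part of the patch] Every hunk in this file follows one mechanical recipe: the fluent REGISTER_USER_OP(...) chain, which collected type-erased std::function callbacks at static-initialization time, is replaced by static methods on a per-op class declared in the generated header op_generated.h, so method names and signatures are checked by the compiler. A minimal standalone model of the two styles follows; all types and the op name DemoOp are stubs invented for this sketch, not OneFlow's real API.

// demo.cpp — compile with: g++ -std=c++14 demo.cpp
#include <functional>
#include <iostream>
#include <string>

struct InferContext { int in_axes = 2; };  // stand-in for user_op::InferContext
using InferFn = std::function<bool(InferContext*)>;

// Old style: a fluent registry object stores callbacks; typos in the chain
// only surface at registration/run time.
struct OpRegistry {
  OpRegistry& Input(const std::string&) { return *this; }
  OpRegistry& Output(const std::string&) { return *this; }
  OpRegistry& SetTensorDescInferFn(InferFn fn) { infer = std::move(fn); return *this; }
  InferFn infer;
};

// New style: a generated class per op; the .cpp fills in static methods, so a
// wrong name or signature is a compile error.
struct DemoOp {  // in OneFlow this declaration would come from op_generated.h
  static bool InferLogicalTensorDesc(InferContext* ctx) { return ctx->in_axes > 0; }
  static bool InferPhysicalTensorDesc(InferContext* ctx) {
    return InferLogicalTensorDesc(ctx);  // the default delegation used across the patch
  }
};

int main() {
  OpRegistry old_style;
  old_style.Input("in").Output("out").SetTensorDescInferFn(
      [](InferContext* ctx) { return ctx->in_axes > 0; });
  InferContext ctx;
  std::cout << old_style.infer(&ctx) << " " << DemoOp::InferPhysicalTensorDesc(&ctx) << "\n";
}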
diff --git a/oneflow/user/ops/tf_prelu_op.cpp b/oneflow/user/ops/tf_prelu_op.cpp
index f264ecb2378..b4880e201e7 100644
--- a/oneflow/user/ops/tf_prelu_op.cpp
+++ b/oneflow/user/ops/tf_prelu_op.cpp
@@ -14,111 +14,106 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("tf_prelu")
-    .Input("x")
-    .Input("alpha")
-    .Output("y")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
-      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
-      const Shape& alpha_shape = ctx->InputShape("alpha", 0);
-      CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1);
-      FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) {
-        CHECK_OR_RETURN((alpha_shape.At(i - 1) == x_desc.shape().At(i))
-                        || (alpha_shape.At(i - 1) == 1));
-      }
-      *y_desc->mut_shape() = x_desc.shape();
-      *y_desc->mut_is_dynamic() = x_desc.is_dynamic();
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      const user_op::TensorDesc& alpha_tensor =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0);
+/*static*/ Maybe<void> TfPreluOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0);
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("x", 0), 0)
+      .Broadcast(user_op::OpArg("alpha", 0))
+      .Split(user_op::OpArg("y", 0), 0)
+      .Build();
+  FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {
+    if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) {
       ctx->NewBuilder()
-          .Split(user_op::OpArg("x", 0), 0)
-          .Broadcast(user_op::OpArg("alpha", 0))
-          .Split(user_op::OpArg("y", 0), 0)
+          .Split(user_op::OpArg("x", 0), i)
+          .Split(user_op::OpArg("alpha", 0), i - 1)
+          .Split(user_op::OpArg("y", 0), i)
           .Build();
-      FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {
-        if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) {
-          ctx->NewBuilder()
-              .Split(user_op::OpArg("x", 0), i)
-              .Split(user_op::OpArg("alpha", 0), i - 1)
-              .Split(user_op::OpArg("y", 0), i)
-              .Build();
-        }
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+    }
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TfPreluOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
+  user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+  const Shape& alpha_shape = ctx->InputShape("alpha", 0);
+  CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_shape.NumAxes() + 1);
+  FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) {
+    CHECK_OR_RETURN((alpha_shape.At(i - 1) == x_desc.shape().At(i))
+                    || (alpha_shape.At(i - 1) == 1));
+  }
+  *y_desc->mut_shape() = x_desc.shape();
+  *y_desc->mut_is_dynamic() = x_desc.is_dynamic();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TfPreluOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TfPreluOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("tf_prelu_grad")
-    .Input("dy")
-    .Input("x")
-    .Input("alpha")
-    .Output("dx")
-    .Output("alpha_diff")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
-      const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
-      user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
-      const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc("alpha", 0);
-      CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1);
-      FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) {
-        CHECK_OR_RETURN((alpha_desc.shape().At(i - 1) == x_desc.shape().At(i))
-                        || (alpha_desc.shape().At(i - 1) == 1));
-      }
-      CHECK_EQ_OR_RETURN(dy_desc.shape(), x_desc.shape());
-      CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type());
-      *dx_desc->mut_shape() = x_desc.shape();
-      *dx_desc->mut_is_dynamic() = x_desc.is_dynamic();
-      *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape();
-      *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic();
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
-      const user_op::TensorDesc& alpha_tensor =
-          ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0);
-      ctx->NewBuilder()
-          .Split(user_op::OpArg("dy", 0), 0)
-          .Split(user_op::OpArg("x", 0), 0)
-          .Broadcast(user_op::OpArg("alpha", 0))
-          .Split(user_op::OpArg("dx", 0), 0)
-          .PartialSum(user_op::OpArg("alpha_diff", 0))
-          .Build();
+/*static*/ Maybe<void> TfPreluGradOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0);
+  const user_op::TensorDesc& alpha_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("alpha", 0);
+  ctx->NewBuilder()
+      .Split(user_op::OpArg("dy", 0), 0)
+      .Split(user_op::OpArg("x", 0), 0)
+      .Broadcast(user_op::OpArg("alpha", 0))
+      .Split(user_op::OpArg("dx", 0), 0)
+      .PartialSum(user_op::OpArg("alpha_diff", 0))
+      .Build();
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("dy", 0))
+      .Broadcast(user_op::OpArg("x", 0))
+      .Broadcast(user_op::OpArg("alpha", 0))
+      .PartialSum(user_op::OpArg("dx", 0))
+      .PartialSum(user_op::OpArg("alpha_diff", 0))
+      .Build();
+  FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {
+    if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) {
       ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("dy", 0))
-          .Broadcast(user_op::OpArg("x", 0))
-          .Broadcast(user_op::OpArg("alpha", 0))
-          .PartialSum(user_op::OpArg("dx", 0))
-          .PartialSum(user_op::OpArg("alpha_diff", 0))
+          .Split(user_op::OpArg("dy", 0), i)
+          .Split(user_op::OpArg("x", 0), i)
+          .Split(user_op::OpArg("alpha", 0), i - 1)
+          .Split(user_op::OpArg("dx", 0), i)
+          .Split(user_op::OpArg("alpha_diff", 0), i - 1)
           .Build();
-      FOR_RANGE(int64_t, i, 1, x_tensor.shape().NumAxes()) {
-        if (x_tensor.shape().At(i) == alpha_tensor.shape().At(i - 1)) {
-          ctx->NewBuilder()
-              .Split(user_op::OpArg("dy", 0), i)
-              .Split(user_op::OpArg("x", 0), i)
-              .Split(user_op::OpArg("alpha", 0), i - 1)
-              .Split(user_op::OpArg("dx", 0), i)
-              .Split(user_op::OpArg("alpha_diff", 0), i - 1)
-              .Build();
-        }
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
-      *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0);
-      return Maybe<void>::Ok();
-    });
+    }
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TfPreluGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
+  const user_op::TensorDesc& dy_desc = ctx->InputTensorDesc("dy", 0);
+  user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0);
+  const user_op::TensorDesc& alpha_desc = ctx->InputTensorDesc("alpha", 0);
+  CHECK_EQ_OR_RETURN(x_desc.shape().NumAxes(), alpha_desc.shape().NumAxes() + 1);
+  FOR_RANGE(int64_t, i, 1, x_desc.shape().NumAxes()) {
+    CHECK_OR_RETURN((alpha_desc.shape().At(i - 1) == x_desc.shape().At(i))
+                    || (alpha_desc.shape().At(i - 1) == 1));
+  }
+  CHECK_EQ_OR_RETURN(dy_desc.shape(), x_desc.shape());
+  CHECK_EQ_OR_RETURN(dy_desc.data_type(), x_desc.data_type());
+  *dx_desc->mut_shape() = x_desc.shape();
+  *dx_desc->mut_is_dynamic() = x_desc.is_dynamic();
+  *ctx->OutputShape("alpha_diff", 0) = alpha_desc.shape();
+  *ctx->OutputIsDynamic("alpha_diff", 0) = alpha_desc.is_dynamic();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TfPreluGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TfPreluGradOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("dx", 0) = ctx->InputDType("x", 0);
+  *ctx->OutputDType("alpha_diff", 0) = ctx->InputDType("alpha", 0);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("tf_prelu")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
diff --git a/oneflow/user/ops/top_k_op.cpp b/oneflow/user/ops/top_k_op.cpp
index 987196d8926..0bcf295d5bd 100644
--- a/oneflow/user/ops/top_k_op.cpp
+++ b/oneflow/user/ops/top_k_op.cpp
@@ -14,35 +14,33 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_NO_GRAD_USER_OP("top_k")
-    .Input("in")
-    .Output("out")
-    .Attr<int32_t>("k")
-    .Attr<bool>("sorted")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const Shape& in_shape = ctx->InputShape("in", 0);
-      Shape* out_shape = ctx->OutputShape("out", 0);
-      *out_shape = in_shape;
-      out_shape->Set(
-          in_shape.NumAxes() - 1,
-          std::min(ctx->Attr<int32_t>("k"), static_cast<int32_t>(in_shape.dim_vec().back())));
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("out", 0) = DataType::kInt64;
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      // The current implementation can only do top_k in the last dimension and should use Broadcast
-      // (by default) instead of Split for that dimension
-      const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-      }
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> TopKOp::GetSbp(user_op::SbpContext* ctx) {
+  // The current implementation can only do top_k in the last dimension and should use Broadcast
+  // (by default) instead of Split for that dimension
+  const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes() - 1) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TopKOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const Shape& in_shape = ctx->InputShape("in", 0);
+  Shape* out_shape = ctx->OutputShape("out", 0);
+  *out_shape = in_shape;
+  out_shape->Set(in_shape.NumAxes() - 1, std::min(ctx->Attr<int32_t>("k"),
+                                                  static_cast<int32_t>(in_shape.dim_vec().back())));
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TopKOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TopKOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("out", 0) = DataType::kInt64;
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
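[Editor's note — illustrative sketch, not part of the patch] TopKOp::GetSbp above registers Split signatures for every axis except the last, because top_k reduces along the last axis and that axis must stay whole on each device. A standalone model of the enumeration, with a made-up rank; not OneFlow API.

// topk_sbp_demo.cpp — compile with: g++ -std=c++14 topk_sbp_demo.cpp
#include <cstdio>

int main() {
  const int num_axes = 3;                   // e.g. a (batch, row, feature) tensor
  for (int i = 0; i < num_axes - 1; ++i) {  // the last axis is deliberately excluded
    std::printf("in: S(%d) -> out: S(%d)\n", i, i);
  }
  // No Split entry is produced for axis num_axes - 1; the framework falls back
  // to Broadcast for it, which is what the comment in the hunk refers to.
}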
*ctx->OutputDType("out", 0) = DataType::kInt64; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/transpose_ops.cpp b/oneflow/user/ops/transpose_ops.cpp index 8a3b849ef7f..9d8130e6efb 100644 --- a/oneflow/user/ops/transpose_ops.cpp +++ b/oneflow/user/ops/transpose_ops.cpp @@ -13,9 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -29,43 +28,41 @@ void CheckIsPerm(const std::vector& perm) { } } -REGISTER_USER_OP("transpose") - .Input("input") - .Output("output") - .Attr>("perm") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("input", 0); - user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("output", 0); - const Shape& in_shape = in_tensor_desc.shape(); - Shape* out_shape = out_tensor_desc->mut_shape(); - const auto& perm = ctx->Attr>("perm"); - CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes()); - CheckIsPerm(perm); - // if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); } - *out_tensor_desc->mut_shape() = in_tensor_desc.shape(); - *out_tensor_desc->mut_is_dynamic() = in_tensor_desc.is_dynamic(); - FOR_RANGE(size_t, i, 0, perm.size()) { out_shape->Set(i, in_shape.At(perm[i])); } - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& input_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); - const auto& perm = ctx->Attr>("perm"); - CHECK_EQ_OR_RETURN(perm.size(), input_tensor.shape().NumAxes()); - FOR_RANGE(int32_t, i, 0, perm.size()) { - int32_t axis = perm.at(i); - if (axis < 0) { axis += perm.size(); } - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LT_OR_RETURN(axis, perm.size()); - ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe TransposeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& input_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("input", 0); + const auto& perm = ctx->Attr>("perm"); + CHECK_EQ_OR_RETURN(perm.size(), input_tensor.shape().NumAxes()); + FOR_RANGE(int32_t, i, 0, perm.size()) { + int32_t axis = perm.at(i); + if (axis < 0) { axis += perm.size(); } + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LT_OR_RETURN(axis, perm.size()); + ctx->NewBuilder().Split(ctx->inputs(), axis).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TransposeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("input", 0); + user_op::TensorDesc* out_tensor_desc = ctx->OutputTensorDesc("output", 0); + const Shape& in_shape = in_tensor_desc.shape(); + Shape* out_shape = out_tensor_desc->mut_shape(); + const auto& perm = ctx->Attr>("perm"); + CHECK_EQ_OR_RETURN(perm.size(), in_shape.NumAxes()); + CheckIsPerm(perm); + // if (perm.at(0) != 0) { CHECK_OR_RETURN(!in_tensor_desc->is_dynamic()); } + 
*out_tensor_desc->mut_shape() = in_tensor_desc.shape(); + *out_tensor_desc->mut_is_dynamic() = in_tensor_desc.is_dynamic(); + FOR_RANGE(size_t, i, 0, perm.size()) { out_shape->Set(i, in_shape.At(perm[i])); } + return Maybe::Ok(); +} +/*static*/ Maybe TransposeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TransposeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("output", 0) = ctx->InputDType("input", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("transpose") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/tril_op.cpp b/oneflow/user/ops/tril_op.cpp index 7324a1a336a..933727beef0 100644 --- a/oneflow/user/ops/tril_op.cpp +++ b/oneflow/user/ops/tril_op.cpp @@ -14,46 +14,43 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("tril") - .Input("in") - .Output("out") - .Attr("diagonal") - .Attr("floating_fill_value", 0) - .Attr("integer_fill_value", 0) - .Attr("is_floating_fill_value", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - bool fill_zero = ctx->Attr("is_floating_fill_value") - ? (ctx->Attr("floating_fill_value") == static_cast(0)) - : (ctx->Attr("integer_fill_value") == static_cast(0)); - if (fill_zero) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe TrilOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + bool fill_zero = ctx->Attr("is_floating_fill_value") + ? 
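[Editor's note — illustrative sketch, not part of the patch] The axis bookkeeping in TransposeOp::GetSbp is easy to misread: output axis i holds input axis perm[i], so splitting the input on axis perm[i] must pair with splitting the output on axis i. A standalone model with a made-up permutation; not OneFlow API.

// transpose_sbp_demo.cpp — compile with: g++ -std=c++14 transpose_sbp_demo.cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> perm = {2, 0, 1};  // hypothetical permutation
  for (int i = 0; i < static_cast<int>(perm.size()); ++i) {
    int axis = perm[i];
    if (axis < 0) { axis += static_cast<int>(perm.size()); }  // same normalization as the hunk
    std::printf("Split(input, %d) <-> Split(output, %d)\n", axis, i);
  }
  // The hunk also registers one PartialSum(inputs) -> PartialSum(outputs)
  // signature, valid because transposition commutes with elementwise addition.
}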
(ctx->Attr("floating_fill_value") == static_cast(0)) + : (ctx->Attr("integer_fill_value") == static_cast(0)); + if (fill_zero) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe TrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe TrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe TrilOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("tril").SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) -> Maybe { @@ -70,46 +67,39 @@ REGISTER_USER_OP_GRAD("tril").SetGenBackwardOpConfFn([](const user_op::UserOpWra return Maybe::Ok(); }); -REGISTER_USER_OP("fused_scale_tril") - .Input("in") - .Output("out") - .Attr("diagonal") - .Attr("floating_fill_value", 0) - .Attr("integer_fill_value", 0) - .Attr("is_floating_fill_value", false) - .Attr("floating_scale_value", 1) - .Attr("integer_scale_value", 1) - .Attr("is_floating_scale_value", false) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - bool fill_zero = ctx->Attr("is_floating_fill_value") - ? (ctx->Attr("floating_fill_value") == static_cast(0)) - : (ctx->Attr("integer_fill_value") == static_cast(0)); - if (fill_zero) { - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - return Maybe::Ok(); - }); +/*static*/ Maybe FusedScaleTrilOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + bool fill_zero = ctx->Attr("is_floating_fill_value") + ? 
(ctx->Attr("floating_fill_value") == static_cast(0)) + : (ctx->Attr("integer_fill_value") == static_cast(0)); + if (fill_zero) { + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + return Maybe::Ok(); +} +/*static*/ Maybe FusedScaleTrilOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe FusedScaleTrilOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe FusedScaleTrilOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + *out->mut_data_type() = in.data_type(); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("fused_scale_tril") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/triu_op.cpp b/oneflow/user/ops/triu_op.cpp index 47a7e48f522..00448d7f585 100644 --- a/oneflow/user/ops/triu_op.cpp +++ b/oneflow/user/ops/triu_op.cpp @@ -14,37 +14,37 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("triu") - .Input("in") - .Output("out") - .Attr("diagonal") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); - *out->mut_shape() = in.shape(); - *out->mut_is_dynamic() = in.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - *out->mut_data_type() = in.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { - ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe TriuOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) { + ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe TriuOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2); + *out->mut_shape() = in.shape(); + *out->mut_is_dynamic() = in.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe TriuOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return 
diff --git a/oneflow/user/ops/triu_op.cpp b/oneflow/user/ops/triu_op.cpp
index 47a7e48f522..00448d7f585 100644
--- a/oneflow/user/ops/triu_op.cpp
+++ b/oneflow/user/ops/triu_op.cpp
@@ -14,37 +14,37 @@ See the License for the specific language governing permissions and
 limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("triu")
-    .Input("in")
-    .Output("out")
-    .Attr<int64_t>("diagonal")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
-      CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);
-      *out->mut_shape() = in.shape();
-      *out->mut_is_dynamic() = in.is_dynamic();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
-      user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
-      *out->mut_data_type() = in.data_type();
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
-      FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {
-        ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
-      }
-      ctx->NewBuilder()
-          .PartialSum(user_op::OpArg("in", 0))
-          .PartialSum(user_op::OpArg("out", 0))
-          .Build();
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> TriuOp::GetSbp(user_op::SbpContext* ctx) {
+  const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0);
+  FOR_RANGE(int64_t, i, 0, in.shape().NumAxes() - 2) {
+    ctx->NewBuilder().Split(ctx->inputs(), i).Split(ctx->outputs(), i).Build();
+  }
+  ctx->NewBuilder()
+      .PartialSum(user_op::OpArg("in", 0))
+      .PartialSum(user_op::OpArg("out", 0))
+      .Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TriuOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
+  CHECK_GE_OR_RETURN(in.shape().NumAxes(), 2);
+  *out->mut_shape() = in.shape();
+  *out->mut_is_dynamic() = in.is_dynamic();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TriuOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TriuOp::InferDataType(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& in = ctx->InputTensorDesc("in", 0);
+  user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0);
+  *out->mut_data_type() = in.data_type();
+  return Maybe<void>::Ok();
+}
 
 }  // namespace oneflow
diff --git a/oneflow/user/ops/tuple_identity_op.cpp b/oneflow/user/ops/tuple_identity_op.cpp
index f829f9d72c5..b777e39fe5b 100644
--- a/oneflow/user/ops/tuple_identity_op.cpp
+++ b/oneflow/user/ops/tuple_identity_op.cpp
@@ -15,52 +15,60 @@ limitations under the License.
 */
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/operator/operator.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("tuple_identity")
-    .InputWithMinimum("in", 1)
-    .OutputWithMinimum("out", 1)
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const int64_t in_size = ctx->input_size("in");
-      CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size);
-      for (int64_t i = 0; i < in_size; ++i) {
-        *ctx->OutputShape("out", i) = ctx->InputShape("in", i);
-        *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i);
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const int64_t in_size = ctx->input_size("in");
-      CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size);
-      for (int64_t i = 0; i < in_size; ++i) {
-        *ctx->OutputDType("out", i) = ctx->InputDType("in", i);
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetSbpSignatureInferFn([](user_op::InferSbpSignatureFnContext* ctx) -> Maybe<void> {
-      cfg::SbpSignature* signature = ctx->mutable_sbp_signature();
-      const cfg::SbpSignature& sbp_signature_conf = ctx->sbp_signature_conf();
-      auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel();
-      const auto& bn2conf_sbp = sbp_signature_conf.bn_in_op2sbp_parallel();
-      const int64_t in_size = ctx->user_op_conf().input_size("in");
-      CHECK_EQ_OR_RETURN(ctx->user_op_conf().output_size("out"), in_size);
-      for (int64_t i = 0; i < in_size; ++i) {
-        const cfg::SbpParallel* sbp_parallel = nullptr;
-        const std::string ibn = GenRepeatedBn("in", i);
-        const std::string& obn = GenRepeatedBn("out", i);
-        const auto& conf_sbp_it = bn2conf_sbp.find(obn);
-        if (conf_sbp_it == bn2conf_sbp.end()) {
-          sbp_parallel = &ctx->SbpParallelHint4InputArgNameAndIndex("in", i);
-        } else {
-          sbp_parallel = &conf_sbp_it->second;
-        }
-        (*bn2sbp)[ibn] = *sbp_parallel;
-        (*bn2sbp)[obn] = *sbp_parallel;
-      }
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast);
+/*static*/ Maybe<void> TupleIdentityOp::GetSbp(user_op::SbpContext* ctx) {
+  return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx);
+}
+/*static*/ Maybe<void> TupleIdentityOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const int64_t in_size = ctx->input_size("in");
+  CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size);
+  for (int64_t i = 0; i < in_size; ++i) {
+    *ctx->OutputShape("out", i) = ctx->InputShape("in", i);
+    *ctx->IsDynamic4ArgNameAndIndex("out", i) = ctx->InputIsDynamic("in", i);
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TupleIdentityOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> TupleIdentityOp::InferDataType(user_op::InferContext* ctx) {
+  const int64_t in_size = ctx->input_size("in");
+  CHECK_EQ_OR_RETURN(ctx->output_size("out"), in_size);
+  for (int64_t i = 0; i < in_size; ++i) { *ctx->OutputDType("out", i) = ctx->InputDType("in", i); }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TupleIdentityOp::InferSbpSignature(
+    user_op::InferSbpSignatureFnContext* ctx) {
+  cfg::SbpSignature* signature = ctx->mutable_sbp_signature();
+  const cfg::SbpSignature& sbp_signature_conf = ctx->sbp_signature_conf();
+  auto* bn2sbp = signature->mutable_bn_in_op2sbp_parallel();
+  const auto& bn2conf_sbp = sbp_signature_conf.bn_in_op2sbp_parallel();
+  const int64_t in_size = ctx->user_op_conf().input_size("in");
+  CHECK_EQ_OR_RETURN(ctx->user_op_conf().output_size("out"), in_size);
+  for (int64_t i = 0; i < in_size; ++i) {
+    const cfg::SbpParallel* sbp_parallel = nullptr;
+    const std::string ibn = GenRepeatedBn("in", i);
+    const std::string& obn = GenRepeatedBn("out", i);
+    const auto& conf_sbp_it = bn2conf_sbp.find(obn);
+    if (conf_sbp_it == bn2conf_sbp.end()) {
+      sbp_parallel = &ctx->SbpParallelHint4InputArgNameAndIndex("in", i);
+    } else {
+      sbp_parallel = &conf_sbp_it->second;
+    }
+    (*bn2sbp)[ibn] = *sbp_parallel;
+    (*bn2sbp)[obn] = *sbp_parallel;
+  }
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> TupleIdentityOp::CheckAttr(const user_op::UserOpDefWrapper&,
+                                                  const user_op::UserOpConfWrapper& op_conf) {
+  CHECK_OR_RETURN(op_conf.input_size("in") >= 1);
+  CHECK_OR_RETURN(op_conf.output_size("out") >= 1);
+  return Maybe<void>::Ok();
+}
 
 REGISTER_USER_OP_GRAD("tuple_identity")
     .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op,
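[Editor's note — illustrative sketch, not part of the patch] TupleIdentityOp::InferSbpSignature above implements a pass-through rule: for each pair i, the output "out_i" takes the SBP pinned in the signature conf if one exists, otherwise the hint inferred for "in_i", and the input is then forced to match, so the op never changes placement. A standalone model of that lookup using strings in place of OneFlow's cfg:: types; all values are made up.

// tuple_identity_sbp_demo.cpp — compile with: g++ -std=c++14 tuple_identity_sbp_demo.cpp
#include <cstdio>
#include <map>
#include <string>
#include <vector>

int main() {
  std::map<std::string, std::string> conf = {{"out_1", "B"}};  // user pinned out_1 only
  std::vector<std::string> hints = {"S(0)", "P"};              // inferred hints for in_0, in_1
  std::map<std::string, std::string> signature;
  for (int i = 0; i < 2; ++i) {
    std::string obn = "out_" + std::to_string(i);
    std::string ibn = "in_" + std::to_string(i);
    auto it = conf.find(obn);
    std::string sbp = (it == conf.end()) ? hints[i] : it->second;
    signature[ibn] = sbp;  // input and output always share one SBP value
    signature[obn] = sbp;
  }
  for (auto& kv : signature) { std::printf("%s: %s\n", kv.first.c_str(), kv.second.c_str()); }
  // Prints in_0/out_0 as S(0) (from the hint) and in_1/out_1 as B (from the conf).
}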
diff --git a/oneflow/user/ops/two_stage_reduce_ops.cpp b/oneflow/user/ops/two_stage_reduce_ops.cpp
index 0d2176faaef..c9c1ac15c11 100644
--- a/oneflow/user/ops/two_stage_reduce_ops.cpp
+++ b/oneflow/user/ops/two_stage_reduce_ops.cpp
@@ -16,6 +16,7 @@ limitations under the License.
 #include "oneflow/core/framework/framework.h"
 #include "oneflow/core/operator/reduce_sbp_util.h"
 #include "oneflow/core/ndarray/binary_func.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
@@ -198,34 +199,41 @@ Maybe<void> GetReduceDeviceStageGradSbpFn(user_op::SbpContext* ctx) {
 
 }  // namespace
 
-#define REGISTER_REDUCE_DEVICE_STAGE_USER_OP(op_name)                           \
-  REGISTER_USER_OP(op_name)                                                     \
-      .Input("in")                                                              \
-      .Output("out")                                                            \
-      .Output("mask")                                                           \
-      .Output("count")                                                          \
-      .Attr<std::vector<int32_t>>("axis")                                       \
-      .SetLogicalTensorDescInferFn(InferReduceDeviceStageLogicalTensorDescFn)   \
-      .SetPhysicalTensorDescInferFn(InferReduceDeviceStagePhysicalTensorDescFn) \
-      .SetDataTypeInferFn(InferReduceDeviceStageDtypeFn)                        \
-      .SetGetSbpFn(GetReduceDeviceStageSbpFn);
-
-REGISTER_REDUCE_DEVICE_STAGE_USER_OP("reduce_min_device_stage")
-REGISTER_REDUCE_DEVICE_STAGE_USER_OP("reduce_max_device_stage")
-
-#define REGISTER_REDUCE_DEVICE_STAGE_GRAD_USER_OP(op_name)          \
-  REGISTER_USER_OP(op_name)                                         \
-      .Input("out_diff")                                            \
-      .Input("mask")                                                \
-      .Input("count")                                               \
-      .Output("in_diff")                                            \
-      .Attr<std::vector<int32_t>>("axis")                           \
-      .SetTensorDescInferFn(InferReduceDeviceStageGradTensorDescFn) \
-      .SetDataTypeInferFn(InferReduceDeviceStageGradDtypeFn)        \
-      .SetGetSbpFn(GetReduceDeviceStageGradSbpFn);
-
-REGISTER_REDUCE_DEVICE_STAGE_GRAD_USER_OP("reduce_min_device_stage_grad")
-REGISTER_REDUCE_DEVICE_STAGE_GRAD_USER_OP("reduce_max_device_stage_grad")
+#define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(op_name)                                \
+  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) {                    \
+    return GetReduceDeviceStageSbpFn(ctx);                                                  \
+  }                                                                                         \
+  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \
+    return InferReduceDeviceStageLogicalTensorDescFn(ctx);                                  \
+  }                                                                                         \
+  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \
+    return InferReduceDeviceStagePhysicalTensorDescFn(ctx);                                 \
+  }                                                                                         \
+  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {           \
+    return InferReduceDeviceStageDtypeFn(ctx);                                              \
+  }
+
+IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMinDeviceStage)
+IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS(ReduceMaxDeviceStage)
+#undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_OP_FUNCS
+
+#define IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(op_name)                               \
+  /*static*/ Maybe<void> op_name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \
+    return GetReduceDeviceStageGradSbpFn(ctx);                                                  \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \
+    return InferReduceDeviceStageGradTensorDescFn(ctx);                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \
+    return InferLogicalTensorDesc(ctx);                                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferDataType(user_op::InferContext* ctx) {           \
+    return InferReduceDeviceStageGradDtypeFn(ctx);                                              \
+  }
+
+IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMinDeviceStage)
+IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS(ReduceMaxDeviceStage)
+#undef IMPLEMENT_REDUCE_DEVICE_STAGE_USER_GRAD_OP_FUNCS
 
 Maybe<void> GenBackwardOpConf4ReduceDeviceStage(const std::string& op_type_name,
                                                 const user_op::UserOpWrapper& op,
@@ -255,58 +263,59 @@
 REGISTER_REDUCE_DEVICE_STAGE_USER_OP_GRAD("reduce_min_device_stage", "reduce_min_device_stage_grad")
 REGISTER_REDUCE_DEVICE_STAGE_USER_OP_GRAD("reduce_max_device_stage", "reduce_max_device_stage_grad")
 
-#define REGISTER_REDUCE_GLOBAL_STAGE_USER_OP(op_name)                             \
-  REGISTER_USER_OP(op_name)                                                       \
-      .Input("in")                                                                \
-      .Input("device_count")                                                      \
-      .Output("out")                                                              \
-      .Output("mask")                                                             \
-      .Attr<std::vector<int32_t>>("axis")                                         \
-      .Attr<bool>("keepdims")                                                     \
-      .SetTensorDescInferFn(InferReduceGlobalStageTensorDescFn)                   \
-      .SetDataTypeInferFn(InferReduceGlobalStageDtypeFn)                          \
-      .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, \
-                              const user_op::UserOpConfWrapper&) -> Maybe<void> { \
-        user_op::InputArgModifier* device_count_modifier =                        \
-            GetInputArgModifierFn("device_count", 0);                             \
-        device_count_modifier->set_requires_grad(false);                          \
-        return Maybe<void>::Ok();                                                 \
-      })                                                                          \
-      .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {                  \
-        ctx->NewBuilder()                                                         \
-            .Split(user_op::OpArg("in", 0), 0)                                    \
-            .Split(user_op::OpArg("device_count", 0), 0)                          \
-            .Split(user_op::OpArg("out", 0), 0)                                   \
-            .Split(user_op::OpArg("mask", 0), 0)                                  \
-            .Build();                                                             \
-        return Maybe<void>::Ok();                                                 \
-      });
-
-REGISTER_REDUCE_GLOBAL_STAGE_USER_OP("reduce_min_global_stage")
-REGISTER_REDUCE_GLOBAL_STAGE_USER_OP("reduce_max_global_stage")
-
-#define REGISTER_REDUCE_GLOBAL_STAGE_GRAD_USER_OP(op_name)              \
-  REGISTER_USER_OP(op_name)                                             \
-      .Input("out_diff")                                                \
-      .Input("mask")                                                    \
-      .Input("device_count")                                            \
-      .Output("in_diff")                                                \
-      .Attr<std::vector<int32_t>>("axis")                               \
-      .Attr<bool>("keepdims")                                           \
-      .SetTensorDescInferFn(InferReduceGlobalStageGradTensorDescFn)     \
-      .SetDataTypeInferFn(InferReduceGlobalStageGradDtypeFn)            \
-      .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {        \
-        ctx->NewBuilder()                                               \
-            .Split(user_op::OpArg("out_diff", 0), 0)                    \
-            .Split(user_op::OpArg("mask", 0), 0)                        \
-            .Split(user_op::OpArg("device_count", 0), 0)                \
-            .Split(user_op::OpArg("in_diff", 0), 0)                     \
-            .Build();                                                   \
-        return Maybe<void>::Ok();                                       \
-      });
-
-REGISTER_REDUCE_GLOBAL_STAGE_GRAD_USER_OP("reduce_min_global_stage_grad")
-REGISTER_REDUCE_GLOBAL_STAGE_GRAD_USER_OP("reduce_max_global_stage_grad")
+#define IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(op_name)                                          \
+  /*static*/ Maybe<void> op_name##Op::GetSbp(user_op::SbpContext* ctx) {                         \
+    ctx->NewBuilder()                                                                            \
+        .Split(user_op::OpArg("in", 0), 0)                                                       \
+        .Split(user_op::OpArg("device_count", 0), 0)                                             \
+        .Split(user_op::OpArg("out", 0), 0)                                                      \
+        .Split(user_op::OpArg("mask", 0), 0)                                                     \
+        .Build();                                                                                \
+    return Maybe<void>::Ok();                                                                    \
+  }                                                                                              \
+  /*static*/ Maybe<void> op_name##Op::InferLogicalTensorDesc(user_op::InferContext* ctx) {       \
+    return InferReduceGlobalStageTensorDescFn(ctx);                                              \
+  }                                                                                              \
+  /*static*/ Maybe<void> op_name##Op::InferPhysicalTensorDesc(user_op::InferContext* ctx) {      \
+    return InferLogicalTensorDesc(ctx);                                                          \
+  }                                                                                              \
+  /*static*/ Maybe<void> op_name##Op::InferDataType(user_op::InferContext* ctx) {                \
+    return InferReduceGlobalStageDtypeFn(ctx);                                                   \
+  }                                                                                              \
+  /*static*/ Maybe<void> op_name##Op::ModifyInputArg(                                            \
+      const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) {     \
+    user_op::InputArgModifier* device_count_modifier = GetInputArgModifierFn("device_count", 0); \
+    device_count_modifier->set_requires_grad(false);                                             \
+    return Maybe<void>::Ok();                                                                    \
+  }
+
+IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMinGlobalStage)
+IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS(ReduceMaxGlobalStage)
+#undef IMPLEMENT_REDUCE_GLOBAL_STAGE_OP_FUNCS
+
+#define IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(op_name)                                    \
+  /*static*/ Maybe<void> op_name##GradOp::GetSbp(user_op::SbpContext* ctx) {                    \
+    ctx->NewBuilder()                                                                           \
+        .Split(user_op::OpArg("out_diff", 0), 0)                                                \
+        .Split(user_op::OpArg("mask", 0), 0)                                                    \
+        .Split(user_op::OpArg("device_count", 0), 0)                                            \
+        .Split(user_op::OpArg("in_diff", 0), 0)                                                 \
+        .Build();                                                                               \
+    return Maybe<void>::Ok();                                                                   \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {  \
+    return InferReduceGlobalStageGradTensorDescFn(ctx);                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { \
+    return InferLogicalTensorDesc(ctx);                                                         \
+  }                                                                                             \
+  /*static*/ Maybe<void> op_name##GradOp::InferDataType(user_op::InferContext* ctx) {           \
+    return InferReduceGlobalStageGradDtypeFn(ctx);                                              \
+  }
+
+IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMinGlobalStage)
+IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS(ReduceMaxGlobalStage)
+#undef IMPLEMENT_REDUCE_GLOBAL_STAGE_GRAD_OP_FUNCS
 
 Maybe<void> GenBackwardOpConf4ReduceGlobalStage(const std::string& op_type_name,
                                                 const user_op::UserOpWrapper& op,
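[Editor's note — illustrative sketch, not part of the patch] The IMPLEMENT_*_FUNCS macros above use token pasting (op_name##Op) to stamp out identical static-method definitions for several generated op classes, then #undef themselves to keep the translation unit clean. A standalone model of the pattern with stub classes invented for this sketch; in OneFlow the class declarations come from op_generated.h.

// xmacro_demo.cpp — compile with: g++ -std=c++14 xmacro_demo.cpp
#include <cstdio>

struct ReduceDemoMinOp { static int GetSbp(); };  // hypothetical op classes
struct ReduceDemoMaxOp { static int GetSbp(); };

// One body, many ops: op_name##Op pastes the class name at expansion time.
#define IMPLEMENT_DEMO_OP_FUNCS(op_name) \
  int op_name##Op::GetSbp() { return 42; /* shared implementation */ }

IMPLEMENT_DEMO_OP_FUNCS(ReduceDemoMin)
IMPLEMENT_DEMO_OP_FUNCS(ReduceDemoMax)
#undef IMPLEMENT_DEMO_OP_FUNCS  // mirrors the patch's hygiene convention

int main() { std::printf("%d %d\n", ReduceDemoMinOp::GetSbp(), ReduceDemoMaxOp::GetSbp()); }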
#include "oneflow/core/framework/framework.h" #include "oneflow/user/ops/nn_util.h" #include "oneflow/core/operator/operator_util.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace user_op { - namespace { Maybe UnfoldTensorDescInferFn(user_op::InferContext* ctx) { @@ -135,30 +134,26 @@ Maybe GetFoldSbpFn(user_op::SbpContext* ctx) { } // namespace -REGISTER_USER_OP("unfold") - .Input("x") - .Output("y") - .Attr("data_format") - .Attr>("kernel_size") - .Attr>("padding") - .Attr>("strides") - .Attr>("dilation_rate") - .SetTensorDescInferFn(UnfoldTensorDescInferFn) - .SetGetSbpFn(GetUnfoldSbpFn) - .SetDataTypeInferFn(SetUnfoldDTypeFn); - -REGISTER_USER_OP("fold") - .Input("x") - .Output("y") - .Attr>("output_size") - .Attr>("kernel_size") - .Attr>("strides") - .Attr>("padding") - .Attr>("dilation_rate") - .SetTensorDescInferFn(FoldTensorDescInferFn) - .SetGetSbpFn(GetFoldSbpFn) - .SetDataTypeInferFn(FoldDTypeFn); - -} // namespace user_op - -} // namespace oneflow \ No newline at end of file +/*static*/ Maybe UnfoldOp::GetSbp(user_op::SbpContext* ctx) { return GetUnfoldSbpFn(ctx); } +/*static*/ Maybe UnfoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return UnfoldTensorDescInferFn(ctx); +} +/*static*/ Maybe UnfoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnfoldOp::InferDataType(user_op::InferContext* ctx) { + return SetUnfoldDTypeFn(ctx); +} + +/*static*/ Maybe FoldOp::GetSbp(user_op::SbpContext* ctx) { return GetFoldSbpFn(ctx); } +/*static*/ Maybe FoldOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return FoldTensorDescInferFn(ctx); +} +/*static*/ Maybe FoldOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe FoldOp::InferDataType(user_op::InferContext* ctx) { + return FoldDTypeFn(ctx); +} + +} // namespace oneflow diff --git a/oneflow/user/ops/unfold_tensor_op.cpp b/oneflow/user/ops/unfold_tensor_op.cpp index 7a6b5f7586f..c383cfee652 100644 --- a/oneflow/user/ops/unfold_tensor_op.cpp +++ b/oneflow/user/ops/unfold_tensor_op.cpp @@ -15,95 +15,83 @@ limitations under the License. 
*/ #include "oneflow/core/framework/framework.h" #include "oneflow/user/kernels/unfold_tensor_kernel_utils.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("unfold_tensor") - .Input("x") - .Output("y") - .Attr("dimension") - .Attr("size") - .Attr("step") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); - const int32_t dimension = ctx->Attr("dimension"); - const int32_t size = ctx->Attr("size"); - const int32_t step = ctx->Attr("step"); +/*static*/ Maybe UnfoldTensorOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t dimension = ctx->Attr("dimension"); + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i != dimension) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), i).Split(user_op::OpArg("y", 0), i).Build(); + } + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); + const int32_t dimension = ctx->Attr("dimension"); + const int32_t size = ctx->Attr("size"); + const int32_t step = ctx->Attr("step"); - const Shape& in_shape = ctx->InputShape("x", 0); - const int32_t in_dim = in_shape.NumAxes(); - CHECK_GE_OR_RETURN(dimension, 0); - CHECK_LE_OR_RETURN(dimension, in_dim - 1); + const Shape& in_shape = ctx->InputShape("x", 0); + const int32_t in_dim = in_shape.NumAxes(); + CHECK_GE_OR_RETURN(dimension, 0); + CHECK_LE_OR_RETURN(dimension, in_dim - 1); - const int32_t max_size = in_dim == 0 ? 1 : in_shape.At(dimension); - CHECK_GT_OR_RETURN(size, 0); - CHECK_LE_OR_RETURN(size, max_size); - CHECK_GT_OR_RETURN(step, 0); + const int32_t max_size = in_dim == 0 ? 
1 : in_shape.At(dimension); + CHECK_GT_OR_RETURN(size, 0); + CHECK_LE_OR_RETURN(size, max_size); + CHECK_GT_OR_RETURN(step, 0); - DimVector out_shape(in_dim + 1); - out_shape[in_dim] = size; - FOR_RANGE(int32_t, d, 0, in_dim) { - int32_t in_size_at_d = in.shape().At(d); - if (d == dimension) { - out_shape.at(d) = (in_size_at_d - size) / step + 1; - } else { - out_shape.at(d) = in_size_at_d; - } - } - *ctx->OutputShape("y", 0) = Shape(out_shape); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t dimension = ctx->Attr("dimension"); - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("x", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i != dimension) { - ctx->NewBuilder() - .Split(user_op::OpArg("x", 0), i) - .Split(user_op::OpArg("y", 0), i) - .Build(); - } - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); + DimVector out_shape(in_dim + 1); + out_shape[in_dim] = size; + FOR_RANGE(int32_t, d, 0, in_dim) { + int32_t in_size_at_d = in.shape().At(d); + if (d == dimension) { + out_shape.at(d) = (in_size_at_d - size) / step + 1; + } else { + out_shape.at(d) = in_size_at_d; + } + } + *ctx->OutputShape("y", 0) = Shape(out_shape); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnfoldTensorOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("unfold_tensor_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("dimension") - .Attr("size") - .Attr("step") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); - const Shape& in_shape = in.shape(); - user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); - *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int32_t dimension = ctx->Attr("dimension"); - const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dx", 0); - FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { - if (i != dimension) { - ctx->NewBuilder() - .Split(user_op::OpArg("dy", 0), i) - .Split(user_op::OpArg("dx", 0), i) - .Build(); - } - } - ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe UnfoldTensorGradOp::GetSbp(user_op::SbpContext* ctx) { + const int32_t dimension = ctx->Attr("dimension"); + const user_op::TensorDesc& x_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("dx", 0); + FOR_RANGE(int64_t, i, 0, x_tensor.shape().NumAxes()) { + if (i != dimension) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), i).Split(user_op::OpArg("dx", 0), i).Build(); + } + } + ctx->NewBuilder().PartialSum(ctx->inputs()).PartialSum(ctx->outputs()).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in = ctx->InputTensorDesc("x", 0); + const Shape& 
in_shape = in.shape(); + user_op::TensorDesc* dx_desc = ctx->OutputTensorDesc("dx", 0); + *dx_desc->mut_shape() = Shape(in_shape.dim_vec()); + return Maybe::Ok(); +} +/*static*/ Maybe UnfoldTensorGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnfoldTensorGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unfold_tensor") .SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { diff --git a/oneflow/user/ops/unique_with_counts_op.cpp b/oneflow/user/ops/unique_with_counts_op.cpp index bf643a7f377..ea0c120dfa7 100644 --- a/oneflow/user/ops/unique_with_counts_op.cpp +++ b/oneflow/user/ops/unique_with_counts_op.cpp @@ -14,52 +14,51 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("unique_with_counts") - .Input("x") - .Output("y") - .Output("idx") - .Output("count") - .Output("num_unique") - .Attr("out_idx", DataType::kInt32) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); +/*static*/ Maybe UniqueWithCountsOp::GetSbp(user_op::SbpContext* ctx) { + return user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast(ctx); +} +/*static*/ Maybe UniqueWithCountsOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + CHECK_EQ_OR_RETURN(x.shape().NumAxes(), 1); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_shape() = x.shape(); - *y->mut_is_dynamic() = x.is_dynamic(); + user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); + *y->mut_shape() = x.shape(); + *y->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); - *idx->mut_shape() = x.shape(); - *idx->mut_is_dynamic() = x.is_dynamic(); + user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); + *idx->mut_shape() = x.shape(); + *idx->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); - *count->mut_shape() = x.shape(); - *count->mut_is_dynamic() = x.is_dynamic(); + user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); + *count->mut_shape() = x.shape(); + *count->mut_is_dynamic() = x.is_dynamic(); - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_shape() = Shape({1}); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); - auto out_idx = ctx->Attr("out_idx"); - CHECK_OR_RETURN(IsIndexDataType(out_idx)); - user_op::TensorDesc* y = ctx->OutputTensorDesc("y", 0); - *y->mut_data_type() = x.data_type(); + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_shape() = Shape({1}); + return Maybe::Ok(); +} +/*static*/ Maybe UniqueWithCountsOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UniqueWithCountsOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0); + auto out_idx = ctx->Attr("out_idx"); + CHECK_OR_RETURN(IsIndexDataType(out_idx)); + user_op::TensorDesc* y = 
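[Editor's note — illustrative sketch, not part of the patch] UnfoldTensorOp::InferLogicalTensorDesc above appends a trailing axis of length `size` and shrinks the unfolded axis to the number of windows, (len - size) / step + 1. A standalone worked example with made-up numbers, independent of OneFlow:

// unfold_shape_demo.cpp — compile with: g++ -std=c++14 unfold_shape_demo.cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> in_shape = {4, 10};  // hypothetical input tensor shape
  int dimension = 1, size = 3, step = 2;
  std::vector<int> out_shape(in_shape.size() + 1);
  out_shape.back() = size;  // new trailing window axis
  for (int d = 0; d < static_cast<int>(in_shape.size()); ++d) {
    out_shape[d] = (d == dimension)
                       ? (in_shape[d] - size) / step + 1  // (10 - 3) / 2 + 1 = 4 windows
                       : in_shape[d];
  }
  for (int v : out_shape) { std::printf("%d ", v); }  // prints: 4 4 3
  std::printf("\n");
}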
ctx->OutputTensorDesc("y", 0); + *y->mut_data_type() = x.data_type(); - user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); - *idx->mut_data_type() = out_idx; + user_op::TensorDesc* idx = ctx->OutputTensorDesc("idx", 0); + *idx->mut_data_type() = out_idx; - user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); - *count->mut_data_type() = out_idx; - user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); - *num_unique->mut_data_type() = out_idx; - return Maybe::Ok(); - }) - .SetGetSbpFn(user_op::GetSbpFnUtil::DefaultBroadcastToBroadcast); + user_op::TensorDesc* count = ctx->OutputTensorDesc("count", 0); + *count->mut_data_type() = out_idx; + user_op::TensorDesc* num_unique = ctx->OutputTensorDesc("num_unique", 0); + *num_unique->mut_data_type() = out_idx; + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/unpack_op.cpp b/oneflow/user/ops/unpack_op.cpp index 0c5c589a70c..b0b4ee12f04 100644 --- a/oneflow/user/ops/unpack_op.cpp +++ b/oneflow/user/ops/unpack_op.cpp @@ -14,55 +14,50 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -namespace { - -REGISTER_USER_OP("unpack") - .Input("in") - .Output("out") - .Attr("unpack_num") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - const Shape& in_shape = in_desc.shape(); - CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0); - const auto unpack_num = ctx->Attr("unpack_num"); - CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0); - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - *out_desc->mut_shape() = in_desc.shape(); - out_desc->mut_shape()->Set(0, in_shape.At(0) / unpack_num); - *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); - const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); - *out_desc->mut_data_type() = in_desc.data_type(); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); - FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("in", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("in", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }) - .SetOutputBlobTimeShapeInferFn( - [](user_op::InferOutputBlobTimeShapeFnContext* ctx) -> Maybe { - const int32_t unpack_num = ctx->user_op_conf().attr("unpack_num"); - DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); - time_shape_dim_vec.emplace_back(unpack_num); - *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); - return Maybe::Ok(); - }); +/*static*/ Maybe UnpackOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& in = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in.shape().NumAxes()) { + ctx->NewBuilder().Split(user_op::OpArg("in", 0), i).Split(user_op::OpArg("out", 0), i).Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe 
UnpackOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + const Shape& in_shape = in_desc.shape(); + CHECK_GT_OR_RETURN(in_shape.NumAxes(), 0); + const auto unpack_num = ctx->Attr("unpack_num"); + CHECK_EQ_OR_RETURN(in_shape.At(0) % unpack_num, 0); + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + *out_desc->mut_shape() = in_desc.shape(); + out_desc->mut_shape()->Set(0, in_shape.At(0) / unpack_num); + *out_desc->mut_is_dynamic() = in_desc.is_dynamic(); + return Maybe::Ok(); +} +/*static*/ Maybe UnpackOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnpackOp::InferDataType(user_op::InferContext* ctx) { + user_op::TensorDesc* out_desc = ctx->OutputTensorDesc("out", 0); + const user_op::TensorDesc& in_desc = ctx->InputTensorDesc("in", 0); + *out_desc->mut_data_type() = in_desc.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe UnpackOp::InferOutputBlobTimeShape( + user_op::InferOutputBlobTimeShapeFnContext* ctx) { + const int32_t unpack_num = ctx->user_op_conf().attr("unpack_num"); + DimVector time_shape_dim_vec = ctx->TimeShape4InputArgNameAndIndex("in", 0).dim_vec(); + time_shape_dim_vec.emplace_back(unpack_num); + *ctx->mut_output_blob_time_shape() = Shape(time_shape_dim_vec); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpConfContext* ctx) -> Maybe { @@ -80,6 +75,4 @@ REGISTER_USER_OP_GRAD("unpack").SetBackwardOpConfGenFn([](user_op::BackwardOpCon return Maybe::Ok(); }); -} // namespace - } // namespace oneflow diff --git a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp index bec3711097a..0ba58274570 100644 --- a/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_batch_segment_sum_op.cpp @@ -14,69 +14,70 @@ See the License for the specific language governing permissions and limitations under the License. 
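
The hunks above and below all apply one mechanical conversion: the fluent `REGISTER_USER_OP(...)` builder, with its `SetTensorDescInferFn` / `SetDataTypeInferFn` / `SetGetSbpFn` lambdas, is replaced by static member functions on an op class declared in the generated header `oneflow/core/framework/op_generated.h`. A minimal sketch of the converted shape, assuming a hypothetical `FooOp` with one input `in` and one output `out` (the class and argument names are illustrative only, not part of this patch):

// Sketch only: FooOp stands in for a schema-generated op class and will not
// compile outside a OneFlow tree whose op_generated.h declares it.
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/op_generated.h"

namespace oneflow {

/*static*/ Maybe<void> FooOp::GetSbp(user_op::SbpContext* ctx) {
  // Allow splitting input and output together along the leading axis.
  ctx->NewBuilder().Split(user_op::OpArg("in", 0), 0).Split(user_op::OpArg("out", 0), 0).Build();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> FooOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
  // Op-specific shape math goes here; the identity shape is shown for brevity.
  *ctx->OutputTensorDesc("out", 0)->mut_shape() = ctx->InputTensorDesc("in", 0).shape();
  return Maybe<void>::Ok();
}
/*static*/ Maybe<void> FooOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
  // Delegating to the logical rule is the common case throughout this patch.
  return InferLogicalTensorDesc(ctx);
}
/*static*/ Maybe<void> FooOp::InferDataType(user_op::InferContext* ctx) {
  *ctx->OutputDType("out", 0) = ctx->InputDType("in", 0);
  return Maybe<void>::Ok();
}

}  // namespace oneflow
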
*/ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("unsorted_batch_segment_sum") - .Input("data") - .Input("segment_ids") - .Output("out") - .Attr("num_segments") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); - const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); - CHECK_GE_OR_RETURN(segment_ids.shape().NumAxes(), 1); - CHECK_GE_OR_RETURN(data.shape().NumAxes(), segment_ids.shape().NumAxes()); - CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic()); - const int64_t num_segments = ctx->Attr("num_segments"); - CHECK_GE_OR_RETURN(num_segments, 1); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); +/*static*/ Maybe UnsortedBatchSegmentSumOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t segment_ids_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); + CHECK_GT_OR_RETURN(segment_ids_num_axes, 1) + << "UnsortedBatchSegmentSumOp: segment_ids_num_axes equals " << segment_ids_num_axes + << " (should be bigger than 1)."; - FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) { - CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i)); - } + FOR_RANGE(int64_t, i, 0, segment_ids_num_axes - 1) { + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); + const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); + CHECK_GE_OR_RETURN(segment_ids.shape().NumAxes(), 1); + CHECK_GE_OR_RETURN(data.shape().NumAxes(), segment_ids.shape().NumAxes()); + CHECK_EQ_OR_RETURN(segment_ids.is_dynamic(), data.is_dynamic()); + const int64_t num_segments = ctx->Attr("num_segments"); + CHECK_GE_OR_RETURN(num_segments, 1); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - DimVector dim_vec(data.shape().dim_vec()); - dim_vec.at(segment_ids.shape().NumAxes() - 1) = num_segments; - *out->mut_shape() = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); - const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); - user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); - CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type())); - *out->mut_data_type() = data.data_type(); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); - segment_ids_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t segment_ids_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); - CHECK_GT_OR_RETURN(segment_ids_num_axes, 1) - << "UnsortedBatchSegmentSumOp: segment_ids_num_axes 
equals " << segment_ids_num_axes - << " (should be bigger than 1)."; + FOR_RANGE(int64_t, i, 0, segment_ids.shape().NumAxes() - 1) { + CHECK_EQ_OR_RETURN(segment_ids.shape().At(i), data.shape().At(i)); + } - FOR_RANGE(int64_t, i, 0, segment_ids_num_axes - 1) { - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); + DimVector dim_vec(data.shape().dim_vec()); + dim_vec.at(segment_ids.shape().NumAxes() - 1) = num_segments; + *out->mut_shape() = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); + const user_op::TensorDesc& segment_ids = ctx->InputTensorDesc("segment_ids", 0); + user_op::TensorDesc* out = ctx->OutputTensorDesc("out", 0); + CHECK_OR_RETURN(IsIndexDataType(segment_ids.data_type())); + *out->mut_data_type() = data.data_type(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedBatchSegmentSumOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); + segment_ids_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unsorted_batch_segment_sum") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/unsorted_segment_sum_op.cpp b/oneflow/user/ops/unsorted_segment_sum_op.cpp index 7dca75f50b0..5df5e81e451 100644 --- a/oneflow/user/ops/unsorted_segment_sum_op.cpp +++ b/oneflow/user/ops/unsorted_segment_sum_op.cpp @@ -14,74 +14,71 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_USER_OP("unsorted_segment_sum") - .Input("data") - .Input("segment_ids") - .Output("out") - .Attr("axis") - .Attr("num_segments") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& data_shape = ctx->InputShape("data", 0); - const int64_t axis = ctx->Attr("axis"); - const int64_t num_segments = ctx->Attr("num_segments"); - Shape* out_shape = ctx->OutputShape("out", 0); - const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); +/*static*/ Maybe UnsortedSegmentSumOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t data_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); + const int64_t segment_ids_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); + const int64_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i + axis) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + FOR_RANGE(int64_t, i, 0, data_num_axes) { + if (i >= axis && i < axis + segment_ids_num_axes) { continue; } + const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; + if (out_split_axis == axis) { continue; } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .Split(user_op::OpArg("data", 0), i) + .Split(user_op::OpArg("out", 0), out_split_axis) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& data_shape = ctx->InputShape("data", 0); + const int64_t axis = ctx->Attr("axis"); + const int64_t num_segments = ctx->Attr("num_segments"); + Shape* out_shape = ctx->OutputShape("out", 0); + const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); - DimVector dim_vec; - dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin(), - data_shape.dim_vec().cbegin() + axis); - dim_vec.emplace_back(num_segments); - dim_vec.insert(dim_vec.end(), - data_shape.dim_vec().cbegin() + axis + segment_ids_shape.NumAxes(), - data_shape.dim_vec().end()); - *out_shape = Shape(dim_vec); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); - *ctx->OutputDType("out", 0) = ctx->InputDType("data", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); - segment_ids_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t data_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); - const int64_t segment_ids_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); - const int64_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i + axis) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - FOR_RANGE(int64_t, i, 0, data_num_axes) { - if (i >= axis && i < axis + segment_ids_num_axes) { continue; } - const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; - if (out_split_axis == axis) { continue; } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .Split(user_op::OpArg("data", 0), i) - .Split(user_op::OpArg("out", 0), out_split_axis) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - return Maybe::Ok(); - }); + DimVector dim_vec; + dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin(), + data_shape.dim_vec().cbegin() + axis); + dim_vec.emplace_back(num_segments); + dim_vec.insert(dim_vec.end(), data_shape.dim_vec().cbegin() + axis + segment_ids_shape.NumAxes(), + data_shape.dim_vec().end()); + *out_shape = Shape(dim_vec); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnsortedSegmentSumOp::InferDataType(user_op::InferContext* ctx) { + CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); + *ctx->OutputDType("out", 0) = ctx->InputDType("data", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); + segment_ids_modifier->set_requires_grad(false); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("unsorted_segment_sum") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, @@ -102,97 +99,95 @@ REGISTER_USER_OP_GRAD("unsorted_segment_sum") return Maybe::Ok(); }); -REGISTER_USER_OP("unsorted_segment_sum_like") - .Input("data") - .Input("segment_ids") - .Input("like") - .Output("out") - .Attr("axis") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& data_shape = ctx->InputShape("data", 0); - const Shape& like_shape = ctx->InputShape("like", 0); - const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); - const int64_t axis = ctx->Attr("axis"); - CHECK_GE_OR_RETURN(axis, 0); - CHECK_LE_OR_RETURN(axis, like_shape.NumAxes()); - FOR_RANGE(int64_t, i, 0, axis) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i)); } - CHECK_EQ_OR_RETURN(data_shape.NumAxes() - segment_ids_shape.NumAxes() + 1, - like_shape.NumAxes()); - FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { - CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); - } - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); - const user_op::TensorDesc& like = ctx->InputTensorDesc("like", 0); - CHECK_EQ_OR_RETURN(data.data_type(), like.data_type()); - CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); - return Maybe::Ok(); - }) - .SetInputArgModifyFn([](user_op::GetInputArgModifier GetInputArgModifierFn, - const user_op::UserOpConfWrapper&) -> Maybe { - user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); - CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); - segment_ids_modifier->set_requires_grad(false); - user_op::InputArgModifier* like_modifier = 
GetInputArgModifierFn("like", 0); - CHECK_NOTNULL_OR_RETURN(like_modifier); - like_modifier->set_requires_grad(false); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - const int64_t data_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); - const int64_t segment_ids_num_axes = - ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); - const int64_t axis = ctx->Attr("axis"); - FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i + axis) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Split(user_op::OpArg("segment_ids", 0), i) - .Split(user_op::OpArg("data", 0), i + axis) - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - } - FOR_RANGE(int64_t, i, 0, data_num_axes) { - if (i >= axis && i < axis + segment_ids_num_axes) { continue; } - const int64_t out_split_axis = (i < axis) ? i : i - segment_ids_num_axes + 1; - if (out_split_axis == axis) { continue; } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .Split(user_op::OpArg("data", 0), i) - .Split(user_op::OpArg("like", 0), out_split_axis) - .Split(user_op::OpArg("out", 0), out_split_axis) - .Build(); - } - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .Broadcast(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .PartialSum(user_op::OpArg("data", 0)) - .PartialSum(user_op::OpArg("like", 0)) - .PartialSum(user_op::OpArg("out", 0)) - .Build(); - ctx->NewBuilder() - .Broadcast(user_op::OpArg("segment_ids", 0)) - .Broadcast(user_op::OpArg("data", 0)) - .Split(user_op::OpArg("like", 0), axis) - .Split(user_op::OpArg("out", 0), axis) - .Build(); - return Maybe::Ok(); - }); +/*static*/ Maybe UnsortedSegmentSumLikeOp::GetSbp(user_op::SbpContext* ctx) { + const int64_t data_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("data", 0).shape().NumAxes(); + const int64_t segment_ids_num_axes = + ctx->LogicalTensorDesc4InputArgNameAndIndex("segment_ids", 0).shape().NumAxes(); + const int64_t axis = ctx->Attr("axis"); + FOR_RANGE(int64_t, i, 0, segment_ids_num_axes) { + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i + axis) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Split(user_op::OpArg("segment_ids", 0), i) + .Split(user_op::OpArg("data", 0), i + axis) + .PartialSum(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + } + FOR_RANGE(int64_t, i, 0, data_num_axes) { + if (i >= axis && i < axis + segment_ids_num_axes) { continue; } + const int64_t out_split_axis = (i < axis) ? 
i : i - segment_ids_num_axes + 1; + if (out_split_axis == axis) { continue; } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .Split(user_op::OpArg("data", 0), i) + .Split(user_op::OpArg("like", 0), out_split_axis) + .Split(user_op::OpArg("out", 0), out_split_axis) + .Build(); + } + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .Broadcast(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .PartialSum(user_op::OpArg("data", 0)) + .PartialSum(user_op::OpArg("like", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + ctx->NewBuilder() + .Broadcast(user_op::OpArg("segment_ids", 0)) + .Broadcast(user_op::OpArg("data", 0)) + .Split(user_op::OpArg("like", 0), axis) + .Split(user_op::OpArg("out", 0), axis) + .Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& data_shape = ctx->InputShape("data", 0); + const Shape& like_shape = ctx->InputShape("like", 0); + const Shape& segment_ids_shape = ctx->InputShape("segment_ids", 0); + const int64_t axis = ctx->Attr("axis"); + CHECK_GE_OR_RETURN(axis, 0); + CHECK_LE_OR_RETURN(axis, like_shape.NumAxes()); + FOR_RANGE(int64_t, i, 0, axis) { CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i)); } + CHECK_EQ_OR_RETURN(data_shape.NumAxes() - segment_ids_shape.NumAxes() + 1, like_shape.NumAxes()); + FOR_RANGE(int64_t, i, axis + 1, like_shape.NumAxes()) { + CHECK_EQ_OR_RETURN(like_shape.At(i), data_shape.At(i + segment_ids_shape.NumAxes() - 1)); + } + *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + *ctx->IsDynamic4ArgNameAndIndex("out", 0) = ctx->InputIsDynamic("like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::InferDataType(user_op::InferContext* ctx) { + const user_op::TensorDesc& data = ctx->InputTensorDesc("data", 0); + const user_op::TensorDesc& like = ctx->InputTensorDesc("like", 0); + CHECK_EQ_OR_RETURN(data.data_type(), like.data_type()); + CHECK_OR_RETURN(IsIndexDataType(ctx->InputDType("segment_ids", 0))); + *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UnsortedSegmentSumLikeOp::ModifyInputArg( + const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { + user_op::InputArgModifier* segment_ids_modifier = GetInputArgModifierFn("segment_ids", 0); + CHECK_NOTNULL_OR_RETURN(segment_ids_modifier); + segment_ids_modifier->set_requires_grad(false); + user_op::InputArgModifier* like_modifier = GetInputArgModifierFn("like", 0); + CHECK_NOTNULL_OR_RETURN(like_modifier); + like_modifier->set_requires_grad(false); + return Maybe::Ok(); +} } // namespace oneflow diff --git a/oneflow/user/ops/upsample_op.cpp b/oneflow/user/ops/upsample_op.cpp index 17e5ca4f84b..0a48bbbfe1a 100644 --- a/oneflow/user/ops/upsample_op.cpp +++ b/oneflow/user/ops/upsample_op.cpp @@ -14,446 +14,386 @@ See the License for the specific language governing permissions and limitations under the License. 
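
The split signatures registered for the segment-sum ops above all come down to one piece of index arithmetic: data axes inside the gathered window `[axis, axis + segment_ids_num_axes)` cannot be split on their own, and every other data axis `i` maps to output axis `i` when `i < axis`, or to `i - segment_ids_num_axes + 1` otherwise, skipping the case where the target lands on `axis` itself (that output axis holds `num_segments`). A standalone sketch of just that mapping, with the three parameters set to illustrative values:

// Mirrors the out_split_axis computation in UnsortedSegmentSumOp::GetSbp above;
// the concrete parameter values below are examples, not taken from the patch.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t data_num_axes = 5;
  const int64_t segment_ids_num_axes = 2;
  const int64_t axis = 1;
  for (int64_t i = 0; i < data_num_axes; ++i) {
    if (i >= axis && i < axis + segment_ids_num_axes) {
      std::cout << "data axis " << i << ": consumed by the gather, no split\n";
      continue;
    }
    const int64_t out_split_axis = (i < axis) ? i : i - segment_ids_num_axes + 1;
    if (out_split_axis == axis) { continue; }  // this out axis holds num_segments
    std::cout << "split data axis " << i << " with out axis " << out_split_axis << "\n";
  }
  return 0;
}

With these values the sketch pairs data axis 0 with out axis 0 and data axes 3 and 4 with out axes 2 and 3, which is exactly the set of signatures the second builder loop (broadcast `segment_ids`, split `data` and `out`) emits.
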
 */
 #include "oneflow/core/framework/framework.h"
+#include "oneflow/core/framework/op_generated.h"
 
 namespace oneflow {
 
-REGISTER_USER_OP("upsample_linear_1d")
-    .Input("x")
-    .Output("y")
-    .Attr<float>("scale_factor")
-    .Attr<bool>("align_corners")
-    .Attr<std::string>("data_format")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
-      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
-      const float scale_factor = ctx->Attr<float>("scale_factor");
+/*static*/ Maybe<void> UpsampleLinear1DOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> UpsampleLinear1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
+  user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
+  const float scale_factor = ctx->Attr<float>("scale_factor");
 
-      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
-                      && x_desc.shape().NumAxes() == 3)
-          << "upsample_linear_1d only supports NCH";
-      *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1),
-                                    static_cast<int64_t>(scale_factor * x_desc.shape().At(2))});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+  CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
+                  && x_desc.shape().NumAxes() == 3)
+      << "upsample_linear_1d only supports NCH";
+  *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1),
+                                static_cast<int64_t>(scale_factor * x_desc.shape().At(2))});
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> UpsampleLinear1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
+  return InferLogicalTensorDesc(ctx);
+}
+/*static*/ Maybe<void> UpsampleLinear1DOp::InferDataType(user_op::InferContext* ctx) {
+  *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
+  return Maybe<void>::Ok();
+}
 
-REGISTER_USER_OP("upsample_nearest_1d")
-    .Input("x")
-    .Output("y")
-    .Attr<float>("scale_factor")
-    .Attr<std::string>("data_format")
-    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
-      user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0);
-      const float scale_factor = ctx->Attr<float>("scale_factor");
-      CHECK_OR_RETURN(ctx->Attr<std::string>("data_format") == "channels_first"
-                      && x_desc.shape().NumAxes() == 3)
-          << "upsample_nearest_1d only supports NCH";
-      *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1),
-                                    static_cast<int64_t>(scale_factor * x_desc.shape().At(2))});
-      return Maybe<void>::Ok();
-    })
-    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
-      ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
-      return Maybe<void>::Ok();
-    })
-    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
-      *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0);
-      return Maybe<void>::Ok();
-    });
+/*static*/ Maybe<void> UpsampleNearest1DOp::GetSbp(user_op::SbpContext* ctx) {
+  ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build();
+  return Maybe<void>::Ok();
+}
+/*static*/ Maybe<void> UpsampleNearest1DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
+  const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0);
user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float scale_factor = ctx->Attr("scale_factor"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 3) + << "upsample_nearest_1d only supports NCH"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(scale_factor * x_desc.shape().At(2))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest1DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_2d") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc.shape().NumAxes() == 4) - << "upsample_nearest_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest2DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 4) + << "upsample_nearest_2d only supports NCHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bilinear_2d") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - 
&& x_desc.shape().NumAxes() == 4) - << "upsample_bilinear_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBilinear2DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 4) + << "upsample_bilinear_2d only supports NCHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBilinear2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bicubic_2d") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc.shape().NumAxes() == 4) - << "upsample_bicubic_2d only supports NCHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBicubic2DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && 
x_desc.shape().NumAxes() == 4) + << "upsample_bicubic_2d only supports NCHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBicubic2DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample") - .Input("x") - .Output("y") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .Attr("interpolation") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - if (ctx->Attr("data_format") != "channels_first" - || x_desc.shape().NumAxes() != 4) { - LOG(FATAL) << "upsample only supports NCHW"; - } - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(height_scale * x_desc.shape().At(2)), - static_cast(width_scale * x_desc.shape().At(3))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + if (ctx->Attr("data_format") != "channels_first" || x_desc.shape().NumAxes() != 4) { + LOG(FATAL) << "upsample only supports NCHW"; + } + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(height_scale * x_desc.shape().At(2)), + static_cast(width_scale * x_desc.shape().At(3))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_3d") - .Input("x") - .Output("y") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc.shape().NumAxes() == 5) - << "upsample_nearest_3d only supports NCDHW"; - 
*y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(depth_scale * x_desc.shape().At(2)), - static_cast(height_scale * x_desc.shape().At(3)), - static_cast(width_scale * x_desc.shape().At(4))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest3DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); + const float depth_scale = ctx->Attr("depth_scale"); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 5) + << "upsample_nearest_3d only supports NCDHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(depth_scale * x_desc.shape().At(2)), + static_cast(height_scale * x_desc.shape().At(3)), + static_cast(width_scale * x_desc.shape().At(4))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_trilinear_3d") - .Input("x") - .Output("y") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); - user_op::TensorDesc* y_desc = ctx->OutputTensorDesc("y", 0); - const float depth_scale = ctx->Attr("depth_scale"); - const float height_scale = ctx->Attr("height_scale"); - const float width_scale = ctx->Attr("width_scale"); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && x_desc.shape().NumAxes() == 5) - << "upsample_trilinear_3d only supports NCDHW"; - *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), - static_cast(depth_scale * x_desc.shape().At(2)), - static_cast(height_scale * x_desc.shape().At(3)), - static_cast(width_scale * x_desc.shape().At(4))}); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleTrilinear3DOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("x", 0), 0).Split(user_op::OpArg("y", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const user_op::TensorDesc& x_desc = ctx->InputTensorDesc("x", 0); + user_op::TensorDesc* y_desc = 
ctx->OutputTensorDesc("y", 0); + const float depth_scale = ctx->Attr("depth_scale"); + const float height_scale = ctx->Attr("height_scale"); + const float width_scale = ctx->Attr("width_scale"); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && x_desc.shape().NumAxes() == 5) + << "upsample_trilinear_3d only supports NCDHW"; + *y_desc->mut_shape() = Shape({x_desc.shape().At(0), x_desc.shape().At(1), + static_cast(depth_scale * x_desc.shape().At(2)), + static_cast(height_scale * x_desc.shape().At(3)), + static_cast(width_scale * x_desc.shape().At(4))}); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleTrilinear3DOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("y", 0) = ctx->InputDType("x", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_linear_1d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("scale_factor") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 3) - << "upsample_linear_1d_grad only supports NCH"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleLinear1DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleLinear1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 3) + << "upsample_linear_1d_grad only supports NCH"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleLinear1DGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleLinear1DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_1d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("scale_factor") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 3) - << "upsample_nearest_1d_grad only supports NCH"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe 
UpsampleNearest1DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 3) + << "upsample_nearest_1d_grad only supports NCH"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest1DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest1DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_2d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 4) - << "upsample_nearest_2d_grad only supports NCHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest2DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 4) + << "upsample_nearest_2d_grad only supports NCHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest2DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bilinear_2d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 4) - << "upsample_bilinear_2d_grad only supports NCHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return 
Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBilinear2DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 4) + << "upsample_bilinear_2d_grad only supports NCHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBilinear2DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBilinear2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_bicubic_2d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 4) - << "upsample_bicubic_2d_grad only supports NCHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleBicubic2DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 4) + << "upsample_bicubic_2d_grad only supports NCHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleBicubic2DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleBicubic2DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .Attr("interpolation") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - if (ctx->Attr("data_format") != "channels_first" || dy_shape.NumAxes() != 4) { - LOG(FATAL) << "upsample_nearest only supports NCHW"; - } - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> 
Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + if (ctx->Attr("data_format") != "channels_first" || dy_shape.NumAxes() != 4) { + LOG(FATAL) << "upsample_nearest only supports NCHW"; + } + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_nearest_3d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 5) - << "upsample_nearest_3d_grad only supports NCDHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleNearest3DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 5) + << "upsample_nearest_3d_grad only supports NCDHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleNearest3DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleNearest3DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} -REGISTER_USER_OP("upsample_trilinear_3d_grad") - .Input("dy") - .Input("x") - .Output("dx") - .Attr("depth_scale") - .Attr("height_scale") - .Attr("width_scale") - .Attr("align_corners") - .Attr("data_format") - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { - const Shape& dy_shape = ctx->InputShape("dy", 0); - Shape* dx_shape = ctx->OutputShape("dx", 0); - CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" - && dy_shape.NumAxes() == 5) - << "upsample_trilinear_3d_grad only supports NCDHW"; - *dx_shape = ctx->InputShape("x", 0); - return Maybe::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { - ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); - return Maybe::Ok(); - }) - 
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); - return Maybe::Ok(); - }); +/*static*/ Maybe UpsampleTrilinear3DGradOp::GetSbp(user_op::SbpContext* ctx) { + ctx->NewBuilder().Split(user_op::OpArg("dy", 0), 0).Split(user_op::OpArg("dx", 0), 0).Build(); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DGradOp::InferLogicalTensorDesc( + user_op::InferContext* ctx) { + const Shape& dy_shape = ctx->InputShape("dy", 0); + Shape* dx_shape = ctx->OutputShape("dx", 0); + CHECK_OR_RETURN(ctx->Attr("data_format") == "channels_first" + && dy_shape.NumAxes() == 5) + << "upsample_trilinear_3d_grad only supports NCDHW"; + *dx_shape = ctx->InputShape("x", 0); + return Maybe::Ok(); +} +/*static*/ Maybe UpsampleTrilinear3DGradOp::InferPhysicalTensorDesc( + user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe UpsampleTrilinear3DGradOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("dx", 0) = ctx->InputDType("dy", 0); + return Maybe::Ok(); +} REGISTER_USER_OP_GRAD("upsample_linear_1d") .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, diff --git a/oneflow/user/ops/where_op.cpp b/oneflow/user/ops/where_op.cpp index df26b3015e9..8dba2951a44 100644 --- a/oneflow/user/ops/where_op.cpp +++ b/oneflow/user/ops/where_op.cpp @@ -14,6 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { @@ -239,7 +240,7 @@ Maybe GetWhereXYScalarSbpSignatures(user_op::SbpContext* ctx) { return Maybe::Ok(); } -Maybe GetWhereInputArgModify(user_op::GetInputArgModifier GetInputArgModifierFn, +Maybe GetWhereInputArgModify(const GetInputArgModifier& GetInputArgModifierFn, const user_op::UserOpConfWrapper&) { user_op::InputArgModifier* cond_arg_modifier = GetInputArgModifierFn("condition", 0); cond_arg_modifier->set_requires_grad(false); @@ -248,101 +249,109 @@ Maybe GetWhereInputArgModify(user_op::GetInputArgModifier GetInputArgModif } // namespace -REGISTER_USER_OP("where") - .Input("condition") - .Input("x") - .Input("y") - .Output("out") - .SetTensorDescInferFn(InferWhereTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - const DataType& x_dtype = ctx->InputDType("x", 0); - CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)); - *ctx->OutputDType("out", 0) = x_dtype; - return Maybe::Ok(); - }) - .SetGetSbpFn(GetWhereSbpSignatures); +/*static*/ Maybe WhereOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereSbpSignatures(ctx); +} +/*static*/ Maybe WhereOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereTensorDesc(ctx); +} +/*static*/ Maybe WhereOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe WhereOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + const DataType& x_dtype = ctx->InputDType("x", 0); + CHECK_EQ_OR_RETURN(x_dtype, ctx->InputDType("y", 0)); + *ctx->OutputDType("out", 0) = x_dtype; + return Maybe::Ok(); +} +/*static*/ Maybe WhereOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + 
return GetWhereInputArgModify(f, conf); +} -REGISTER_USER_OP("where_scalar_x") - .Input("condition") - .Input("y") - .Output("out") - .Attr<bool>("has_int_operand") - .Attr<bool>("has_float_operand") - .Attr<int64_t>("int_operand") - .Attr<double>("float_operand") - .SetTensorDescInferFn(InferWhereXScalarTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - const DataType& y_dtype = ctx->InputDType("y", 0); - if (ctx->Attr<bool>("has_int_operand")) { - CHECK_EQ_OR_RETURN(y_dtype, GetDataType<int64_t>::value) - << "expected scalar type " << GetDataType<int64_t>::value << "but found " << y_dtype; - } else if (ctx->Attr<bool>("has_float_operand")) { - CHECK_EQ_OR_RETURN(y_dtype, GetDataType<double>::value) - << "expected scalar type " << GetDataType<double>::value << "but found " << y_dtype; - } - *ctx->OutputDType("out", 0) = y_dtype; - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetWhereXScalarSbpSignatures); +/*static*/ Maybe<void> WhereScalarXOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereXScalarSbpSignatures(ctx); +} +/*static*/ Maybe<void> WhereScalarXOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereXScalarTensorDesc(ctx); +} +/*static*/ Maybe<void> WhereScalarXOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> WhereScalarXOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + const DataType& y_dtype = ctx->InputDType("y", 0); + if (ctx->Attr<bool>("has_int_operand")) { + CHECK_EQ_OR_RETURN(y_dtype, GetDataType<int64_t>::value) + << "expected scalar type " << GetDataType<int64_t>::value << "but found " << y_dtype; + } else if (ctx->Attr<bool>("has_float_operand")) { + CHECK_EQ_OR_RETURN(y_dtype, GetDataType<double>::value) + << "expected scalar type " << GetDataType<double>::value << "but found " << y_dtype; + } + *ctx->OutputDType("out", 0) = y_dtype; + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> WhereScalarXOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} -REGISTER_USER_OP("where_scalar_y") - .Input("condition") - .Input("x") - .Output("out") - .Attr<bool>("has_int_operand") - .Attr<bool>("has_float_operand") - .Attr<int64_t>("int_operand") - .Attr<double>("float_operand") - .SetTensorDescInferFn(InferWhereYScalarTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - const DataType& x_dtype = ctx->InputDType("x", 0); - if (ctx->Attr<bool>("has_int_operand")) { - CHECK_EQ_OR_RETURN(x_dtype, GetDataType<int64_t>::value) - << "expected scalar type " << x_dtype << "but found " << GetDataType<int64_t>::value; - } else if (ctx->Attr<bool>("has_float_operand")) { - CHECK_EQ_OR_RETURN(x_dtype, GetDataType<double>::value) - << "expected scalar type " << x_dtype << "but found " << GetDataType<double>::value; - } - *ctx->OutputDType("out", 0) = x_dtype; - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetWhereYScalarSbpSignatures); +/*static*/ Maybe<void> WhereScalarYOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereYScalarSbpSignatures(ctx); +} +/*static*/ Maybe<void> WhereScalarYOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereYScalarTensorDesc(ctx); +} +/*static*/ Maybe<void> WhereScalarYOp::InferPhysicalTensorDesc(user_op::InferContext* ctx)
{ + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> WhereScalarYOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + const DataType& x_dtype = ctx->InputDType("x", 0); + if (ctx->Attr<bool>("has_int_operand")) { + CHECK_EQ_OR_RETURN(x_dtype, GetDataType<int64_t>::value) + << "expected scalar type " << x_dtype << "but found " << GetDataType<int64_t>::value; + } else if (ctx->Attr<bool>("has_float_operand")) { + CHECK_EQ_OR_RETURN(x_dtype, GetDataType<double>::value) + << "expected scalar type " << x_dtype << "but found " << GetDataType<double>::value; + } + *ctx->OutputDType("out", 0) = x_dtype; + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> WhereScalarYOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} -REGISTER_NO_GRAD_USER_OP("where_scalar_xy") - .Input("condition") - .Output("out") - .Attr<bool>("has_x_int_operand") - .Attr<bool>("has_x_float_operand") - .Attr<bool>("has_y_int_operand") - .Attr<bool>("has_y_float_operand") - .Attr<int64_t>("x_int_operand") - .Attr<double>("x_float_operand") - .Attr<int64_t>("y_int_operand") - .Attr<double>("y_float_operand") - .SetTensorDescInferFn(InferWhereXYScalarTensorDesc) - .SetInputArgModifyFn(GetWhereInputArgModify) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - const DataType& cond_dtype = ctx->InputDType("condition", 0); - CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); - if (ctx->Attr<bool>("has_x_int_operand") && ctx->Attr<bool>("has_y_int_operand")) { - *ctx->OutputDType("out", 0) = GetDataType<int64_t>::value; - } else if (ctx->Attr<bool>("has_x_float_operand") && ctx->Attr<bool>("has_y_float_operand")) { - *ctx->OutputDType("out", 0) = GetDataType<double>::value; - } else { - UNIMPLEMENTED(); - } - return Maybe<void>::Ok(); - }) - .SetGetSbpFn(GetWhereXYScalarSbpSignatures); +/*static*/ Maybe<void> WhereScalarXyOp::GetSbp(user_op::SbpContext* ctx) { + return GetWhereXYScalarSbpSignatures(ctx); +} +/*static*/ Maybe<void> WhereScalarXyOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + return InferWhereXYScalarTensorDesc(ctx); +} +/*static*/ Maybe<void> WhereScalarXyOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> WhereScalarXyOp::InferDataType(user_op::InferContext* ctx) { + const DataType& cond_dtype = ctx->InputDType("condition", 0); + CHECK_OR_RETURN(IsIntegralDataType(cond_dtype)); + if (ctx->Attr<bool>("has_x_int_operand") && ctx->Attr<bool>("has_y_int_operand")) { + *ctx->OutputDType("out", 0) = GetDataType<int64_t>::value; + } else if (ctx->Attr<bool>("has_x_float_operand") && ctx->Attr<bool>("has_y_float_operand")) { + *ctx->OutputDType("out", 0) = GetDataType<double>::value; + } else { + UNIMPLEMENTED(); + } + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> WhereScalarXyOp::ModifyInputArg(const GetInputArgModifier& f, + const user_op::UserOpConfWrapper& conf) { + return GetWhereInputArgModify(f, conf); +} REGISTER_USER_OP_GRAD("where").SetBackwardOpConfGenFn( [](user_op::BackwardOpConfContext* ctx) -> Maybe<void> {
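The /*static*/ definitions above compile against declarations that oneflow-tblgen emits into oneflow/core/framework/op_generated.h (the op_schema_header.inc template appears later in this diff). Purely as an illustration of that contract, the generated declaration for WhereOp would look roughly like this, with one static hook per has_*_fn bit set in the op's ODS record:

class WhereOp : public OpBase {
 public:
  virtual ~WhereOp() = default;
  static Maybe<void> GetSbp(user_op::SbpContext* ctx);
  static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
  static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx);
  static Maybe<void> InferDataType(user_op::InferContext* ctx);
  static Maybe<void> ModifyInputArg(const GetInputArgModifier&, const user_op::UserOpConfWrapper&);
  const HashSet<std::string>& AttrNames() const;
};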
diff --git a/oneflow/user/ops/zero_like_op.cpp b/oneflow/user/ops/zero_like_op.cpp index 193ad79666b..6e650556069 100644 --- a/oneflow/user/ops/zero_like_op.cpp +++ b/oneflow/user/ops/zero_like_op.cpp @@ -14,35 +14,34 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/core/framework/op_generated.h" namespace oneflow { -REGISTER_NO_GRAD_USER_OP("zero_like") - .Input("like") - .Output("out") - .SetOutputBufferNum(1) - .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); - return Maybe<void>::Ok(); - }) - .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> { - *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); - return Maybe<void>::Ok(); - }) - .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { - const user_op::TensorDesc& like_tensor = - ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); - FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { - ctx->NewBuilder() - .Split(user_op::OpArg("like", 0), i) - .Split(user_op::OpArg("out", 0), i) - .Build(); - } - ctx->NewBuilder() - .PartialSum(user_op::OpArg("like", 0)) - .Broadcast(user_op::OpArg("out", 0)) - .Build(); - return Maybe<void>::Ok(); - }); +/*static*/ Maybe<void> ZeroLikeOp::GetSbp(user_op::SbpContext* ctx) { + const user_op::TensorDesc& like_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("like", 0); + FOR_RANGE(int64_t, i, 0, like_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("like", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("like", 0)) + .Broadcast(user_op::OpArg("out", 0)) + .Build(); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ZeroLikeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) { + *ctx->OutputShape("out", 0) = ctx->InputShape("like", 0); + return Maybe<void>::Ok(); +} +/*static*/ Maybe<void> ZeroLikeOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) { + return InferLogicalTensorDesc(ctx); +} +/*static*/ Maybe<void> ZeroLikeOp::InferDataType(user_op::InferContext* ctx) { + *ctx->OutputDType("out", 0) = ctx->InputDType("like", 0); + return Maybe<void>::Ok(); +} } // namespace oneflow diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index 5c2df72e901..2cf61d151c2 100755 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -15,6 +15,12 @@ """ import os + +if os.getenv("CTEST_RESOURCE_GROUP_COUNT"): + vram_str = os.getenv("CTEST_RESOURCE_GROUP_0_VRAM") + gpu_id = vram_str.split(",")[0].split(":")[-1] + os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id + import sys import collections diff --git a/python/oneflow/compatible/single_client/ops/math_ops.py b/python/oneflow/compatible/single_client/ops/math_ops.py index f96a428bea6..19d720f80c1 100644 --- a/python/oneflow/compatible/single_client/ops/math_ops.py +++ b/python/oneflow/compatible/single_client/ops/math_ops.py @@ -594,8 +594,8 @@ def reluJob(x: tp.Numpy.Placeholder((3, )) return ( flow.user_op_builder(name if name is not None else id_util.UniqueStr("Relu_")) .Op("relu") - .Input("in", [x]) - .Output("out") + .Input("x", [x]) + .Output("y") .Build() .InferAndTryRun() .RemoteBlobList()[0] diff --git a/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py b/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py index 3ee8f22a532..3431d39121b 100644 --- a/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py +++ b/python/oneflow/compatible/single_client/test/ops/test_ccrelu.py @@ -28,8 +28,8 @@ def ccrelu(x, name): return ( flow.user_op_builder(name) .Op("ccrelu") - .Input("in", [x]) - .Output("out") + .Input("x", [x]) + .Output("y") .Build() .InferAndTryRun() .RemoteBlobList()[0]
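A worked example for the CTest hook added to python/oneflow/__init__.py above: when ctest schedules a test inside a resource group, it exports variables of the form CTEST_RESOURCE_GROUP_0_VRAM="id:1,slots:2000" (illustrative value, following CTest's resource-allocation convention). split(",")[0] keeps "id:1" and split(":")[-1] keeps "1", so the test process is pinned to GPU 1 via CUDA_VISIBLE_DEVICES while ctest debits the declared vram slots from that GPU's budget; this is what lets the suite pack tests onto GPUs without oversubscribing memory.

diff --git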
a/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py b/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py index 4060d353959..ba5fd20eec9 100644 --- a/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py +++ b/python/oneflow/compatible/single_client/test/ops/test_multi_global_function.py @@ -28,8 +28,8 @@ def ccrelu(x, name): return ( flow.user_op_builder(name) .Op("ccrelu") - .Input("in", [x]) - .Output("out") + .Input("x", [x]) + .Output("y") .Build() .InferAndTryRun() .RemoteBlobList()[0] diff --git a/tools/oneflow-tblgen/CMakeLists.txt b/tools/oneflow-tblgen/CMakeLists.txt new file mode 100644 index 00000000000..fdcff1c70df --- /dev/null +++ b/tools/oneflow-tblgen/CMakeLists.txt @@ -0,0 +1,41 @@ +set(LLVM_LINK_COMPONENTS + Support +) +include(FetchContent) + +FetchContent_Declare( + inja +) +FetchContent_GetProperties(inja) +if(NOT inja_POPULATED) + FetchContent_Populate(inja + URL ${INJA_URL} + URL_HASH MD5=${INJA_MD5} + ) +endif() +include_directories(${inja_SOURCE_DIR}/include/inja) + +FetchContent_Declare( + json +) +FetchContent_GetProperties(json) +if(NOT json_POPULATED) + FetchContent_Populate(json + URL ${JSON_URL} + URL_HASH MD5=${JSON_MD5} + ) +endif() +include_directories(${json_SOURCE_DIR}/include) + +add_tablegen(oneflow_tblgen llvm + tablegen.cpp + op_schema_emitter.cpp +) + +install(TARGETS oneflow_tblgen LLVMTableGen LLVMDemangle LLVMSupport COMPONENT OneFlowTableGen) +add_custom_target(install-oneflow-tblgen + DEPENDS oneflow_tblgen + COMMAND + "${CMAKE_COMMAND}" -DCMAKE_INSTALL_COMPONENT=OneFlowTableGen + -P "${CMAKE_BINARY_DIR}/cmake_install.cmake" +) diff --git a/tools/oneflow-tblgen/backends.h b/tools/oneflow-tblgen/backends.h new file mode 100644 index 00000000000..d1da04ac2ea --- /dev/null +++ b/tools/oneflow-tblgen/backends.h @@ -0,0 +1,39 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#ifndef ONEFLOW_TBLGEN_BACKENDS_H +#define ONEFLOW_TBLGEN_BACKENDS_H + +namespace llvm { +class raw_ostream; +class RecordKeeper; +} // namespace llvm + +namespace oneflow { + +namespace tblgen { + +using llvm::raw_ostream; +using llvm::RecordKeeper; + +void EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& OS); +void EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& OS); + +} // namespace tblgen + +} // namespace oneflow + +#endif // ONEFLOW_TBLGEN_BACKENDS_H diff --git a/tools/oneflow-tblgen/example/constant.td b/tools/oneflow-tblgen/example/constant.td new file mode 100644 index 00000000000..561d2999bfb --- /dev/null +++ b/tools/oneflow-tblgen/example/constant.td @@ -0,0 +1,17 @@ +include "mlir/Interfaces/SideEffectInterfaces.td" +include "OneFlowEnums.td" +include "OneFlowBase.td" + +def OneFlow_ConstantOp : OneFlow_BaseOp<"constant", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> { + let output = (outs + AnyType:$out + ); + let attrs = (ins + DefaultValuedAttr<F64Attr, "0.">:$floating_value, + DefaultValuedAttr<SI64Attr, "0">:$integer_value, + DefaultValuedAttr<BoolAttr, "false">:$is_floating_value, + StrAttr:$dtype, + AnyI64ElementsAttr:$shape, + StrArrayAttr:$nd_sbp + ); +}
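This constant.td example is exactly the kind of record the emitter in the next file consumes. For the definition above, OpSchemaEmitter::run() would assemble roughly the following inja::json payload before rendering (a sketch; attr type strings follow op_schema_types.inc, and the attrs array is abbreviated):

json op{{"name", "constant"},
        {"input", json::array()},  // no inputs declared
        {"output", {{{"name", "out"}, {"is_optional", false}, {"size", 1}}}},
        {"attrs", {{{"name", "floating_value"}, {"type", "double"}, {"default", "0."}},
                   {{"name", "dtype"}, {"type", "std::string"}}
                   /* integer_value, is_floating_value, shape, nd_sbp ... */}}};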
diff --git a/tools/oneflow-tblgen/op_schema_emitter.cpp b/tools/oneflow-tblgen/op_schema_emitter.cpp new file mode 100644 index 00000000000..bc313f78bb0 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_emitter.cpp @@ -0,0 +1,242 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Format.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/TableGen/Error.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/TableGenBackend.h" +#include "inja.hpp" + +#include <fstream> +#include <iomanip> + +using namespace llvm; +using inja::json; + +namespace oneflow { +namespace tblgen { + +cl::OptionCategory opSchemaCat("Options for -gen-op-schema"); + +cl::opt<std::string> sourceIncludeFilename{ "op-include", cl::desc("header filename to include in source file"), cl::value_desc("include filename"), cl::init(""), cl::cat(opSchemaCat)}; + +cl::opt<std::string> dumpJson{"op-dump-json", cl::desc("dump tablegen code to json in provided file"), cl::value_desc("filename"), cl::init(""), cl::cat(opSchemaCat)}; + +enum class FileTarget { + kHeader = 1, + kSource, +}; + +template<FileTarget Target> +class OpSchemaEmitter { + public: + explicit OpSchemaEmitter(RecordKeeper& RK); + + void run(raw_ostream& os); + + void emitInputAndOutput(const Record* def, json* op) const; + + void emitAttrs(const Record* def, json* op) const; + + void emitInt(const Record* def, StringRef fieldname, json* op) const; + void emitBit(const Record* def, StringRef fieldname, json* op) const; + void emitTrait(const Record* def, StringRef fieldname, StringRef traitname, json* op) const; + + private: + static std::string emitType(const std::string& ods_type) { +#define OP_SCHEMA(ods, cpp) \ + if (ods_type == #ods) return #cpp; +#include "op_schema_types.inc" +#undef OP_SCHEMA + PrintFatalError("undefined attribute type: " + ods_type); + } + + private: + RecordKeeper& records; + + StringRef op_type_name; + StringRef op_name; + + inja::Environment env; + inja::Template temp; + static const std::string code; +}; + +template<FileTarget Target> +OpSchemaEmitter<Target>::OpSchemaEmitter(RecordKeeper& RK) : records(RK) { + env.add_callback("quoted", 1, [](inja::Arguments& args) { + auto str = args.at(0)->get<std::string>(); + std::ostringstream os; + os << std::quoted(str); + return os.str(); + }); + env.add_callback("to_header", 1, [](inja::Arguments& args) { + auto str = args.at(0)->get<std::string>(); + auto dot_pos = str.find_last_of('.'); + if (dot_pos != std::string::npos) { str.replace(dot_pos, str.size() - dot_pos, ".h"); } + + // assume that the source and header file is in the same directory + auto slash_pos = str.find_last_of('/'); + if (slash_pos != std::string::npos) { str.replace(0, slash_pos + 1, ""); } + return str; + }); + temp = env.parse(code); +} + +template<FileTarget Target> +void OpSchemaEmitter<Target>::run(raw_ostream& os) { + emitSourceFileHeader("oneflow op schema", os); + json ops = json::object(); + + for (const auto& def : records.getAllDerivedDefinitions("OneFlow_BaseOp")) { + op_type_name = def->getValueAsString("opName"); + if (op_type_name.empty()) { + PrintFatalError(def, "`opName` of op definitions cannot be omitted"); + } + op_name = def->getName(); + if (!op_name.consume_front("OneFlow_")) { + PrintFatalError(def, "op name is not start with `OneFlow_`: " + op_name.str()); + } + json op{{"name", op_type_name}, + {"input", json::array()}, + {"output", json::array()}, + {"attrs", json::array()}}; + + emitInputAndOutput(def, &op); + emitAttrs(def, &op); + emitInt(def, "same_output_regst_num", &op); + emitTrait(def, "no_grad", "NoGrad", &op); + emitTrait(def, "cpu_only", "CpuOnly", &op); + emitBit(def, "has_nd_sbp_infer_fn", &op); + emitBit(def, "has_get_sbp_fn", &op); +
emitBit(def, "has_logical_tensor_desc_infer_fn", &op); + emitBit(def, "has_physical_tensor_desc_infer_fn", &op); + emitBit(def, "has_data_type_infer_fn", &op); + emitBit(def, "has_device_infer_fn", &op); + emitBit(def, "has_input_arg_modify_fn", &op); + emitBit(def, "has_output_arg_modify_fn", &op); + emitBit(def, "has_output_blob_time_shape_infer_fn", &op); + emitBit(def, "has_sbp_signature_infer_fn", &op); + emitBit(def, "has_check_fn", &op); + ops[op_name.str()] = op; + } + + auto* option = static_cast*>(cl::getRegisteredOptions().lookup("o")); + auto filename = option->getValue(); + filename = filename != "-" ? filename : ""; + json data{{"filename", filename}, {"ops", ops}}; + + if (Target == FileTarget::kSource) { data["include"] = sourceIncludeFilename; } + if (!dumpJson.empty()) { + std::ofstream file(dumpJson); + file << data.dump(); + } + os << env.render(temp, data); +} + +template +void OpSchemaEmitter::emitInputAndOutput(const Record* def, json* op) const { + const auto* input = def->getValueAsDag("input"); + for (size_t i = 0; i < input->getNumArgs(); ++i) { + const auto* A = dyn_cast(input->getArg(i))->getDef(); + bool is_optional = A->isSubClassOf("Optional"); + auto NS = input->getArgName(i)->getAsUnquotedString(); + (*op)["input"].push_back({{"name", NS}, {"is_optional", is_optional}, {"size", 1}}); + } + const auto* output = def->getValueAsDag("output"); + for (size_t i = 0; i < output->getNumArgs(); ++i) { + const auto* A = dyn_cast(output->getArg(i))->getDef(); + bool is_optional = A->isSubClassOf("Optional"); + auto NS = output->getArgName(i)->getAsUnquotedString(); + (*op)["output"].push_back({{"name", NS}, {"is_optional", is_optional}, {"size", 1}}); + } +} + +template +void OpSchemaEmitter::emitAttrs(const Record* def, json* op) const { + const auto* attrs = def->getValueAsDag("attrs"); + for (size_t i = 0; i < attrs->getNumArgs(); ++i) { + const auto* A = dyn_cast(attrs->getArg(i))->getDef(); + std::string AS; + if (!A->isAnonymous()) { + AS = A->getNameInitAsString(); + } else { + AS = A->getValueAsDef("baseAttr")->getNameInitAsString(); + } + auto NS = attrs->getArgName(i)->getAsUnquotedString(); + json attr{{"name", NS}, {"type", emitType(AS)}}; + + if (auto DV = A->getValueAsOptionalString("defaultValue")) { attr["default"] = DV.getValue(); } + + (*op)["attrs"].push_back(attr); + } +} + +template +void OpSchemaEmitter::emitBit(const Record* def, StringRef fieldname, json* op) const { + (*op)[fieldname.str()] = def->getValueAsBit(fieldname); +} + +template +void OpSchemaEmitter::emitTrait(const Record* def, StringRef fieldname, StringRef traitname, + json* op) const { + bool hasTrait = false; + + for (auto elem : *def->getValueAsListInit("traits")) { + if (elem->getAsString() == traitname) { + hasTrait = true; + break; + } + } + + (*op)[fieldname.str()] = hasTrait; +} + +template +void OpSchemaEmitter::emitInt(const Record* def, StringRef fieldname, json* op) const { + (*op)[fieldname.str()] = def->getValueAsInt(fieldname); +} + +template<> +const std::string OpSchemaEmitter::code{ +#include "op_schema_header.inc" +}; + +template<> +const std::string OpSchemaEmitter::code{ +#include "op_schema_source.inc" +}; + +void EmitOpSchemaHeader(RecordKeeper& RK, raw_ostream& os) { + OpSchemaEmitter(RK).run(os); +} + +void EmitOpSchemaSource(RecordKeeper& RK, raw_ostream& os) { + OpSchemaEmitter(RK).run(os); +} + +} // namespace tblgen +} // namespace oneflow diff --git a/tools/oneflow-tblgen/op_schema_header.inc b/tools/oneflow-tblgen/op_schema_header.inc new file mode 
diff --git a/tools/oneflow-tblgen/op_schema_header.inc b/tools/oneflow-tblgen/op_schema_header.inc new file mode 100644 index 00000000000..ae167d08921 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_header.inc @@ -0,0 +1,100 @@ +R"OP_SCHEMA_INC( +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/common/shape.h" +#include "oneflow/core/common/symbol.h" +#include "oneflow/core/framework/op_base.h" + +#include <functional> +#include <string> +#include <vector> + +namespace oneflow { + +class Device; +class InputBlobModifier; +class OutputBlobModifier; + +namespace user_op { +class UserOpDefWrapper; +class UserOpConfWrapper; +class InferContext; +class SbpContext; +class InferSbpSignatureFnContext; +class InferOutputBlobTimeShapeFnContext; +class InferNdSbpFnContext; +class DeviceInferContext; +} // namespace user_op + +using GetInputArgModifier = + std::function<InputBlobModifier*(const std::string&, int32_t)>; +using GetOutputArgModifier = + std::function<OutputBlobModifier*(const std::string&, int32_t)>; + +{% for opname, op in ops %} +class {{opname}} : public OpBase { + public: + virtual ~{{opname}}() = default; + {% if op.has_nd_sbp_infer_fn -%} + static Maybe<void> InferNdSbp(user_op::InferNdSbpFnContext* ctx); + {% endif -%} + {% if op.has_get_sbp_fn -%} + static Maybe<void> GetSbp(user_op::SbpContext* ctx); + {% endif -%} + {% if op.has_logical_tensor_desc_infer_fn -%} + static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx); + {% endif -%} + {% if op.has_physical_tensor_desc_infer_fn -%} + static Maybe<void> InferPhysicalTensorDesc(user_op::InferContext* ctx); + {% endif -%} + {% if op.has_data_type_infer_fn -%} + static Maybe<void> InferDataType(user_op::InferContext* ctx); + {% endif -%} + {% if op.has_device_infer_fn -%} + static Maybe<Symbol<Device>> InferDevice(user_op::DeviceInferContext* ctx); + {% endif -%} + {% if op.has_sbp_signature_infer_fn -%} + static Maybe<void> InferSbpSignature(user_op::InferSbpSignatureFnContext* ctx); + {% endif -%} + {% if op.has_input_arg_modify_fn -%} + static Maybe<void> ModifyInputArg(const GetInputArgModifier&, const user_op::UserOpConfWrapper&); + {% endif -%} + {% if op.has_output_arg_modify_fn -%} + static Maybe<void> ModifyOutputArg(const GetOutputArgModifier&, const user_op::UserOpConfWrapper&); + {% endif -%} + {% if op.has_output_blob_time_shape_infer_fn -%} + static Maybe<void> InferOutputBlobTimeShape(user_op::InferOutputBlobTimeShapeFnContext* ctx); + {% endif -%} + {% if op.has_check_fn -%} + static Maybe<void> CheckAttr(const user_op::UserOpDefWrapper&, const user_op::UserOpConfWrapper&); + {% endif -%} + + {% for attr in op.attrs -%} + virtual const {{attr.type}}& {{attr.name}}() const = 0; + virtual {{attr.type}}* mutable_{{attr.name}}() = 0; + virtual void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) = 0; + + {% endfor -%} + const HashSet<std::string>& AttrNames() const; +}; + +namespace schema { +class {{opname}} : public oneflow::{{opname}} { + public: + {% for attr in op.attrs -%} + const {{attr.type}}& {{attr.name}}() const override { return {{attr.name}}_; } + {{attr.type}}* mutable_{{attr.name}}() override { return &{{attr.name}}_; } + void set_{{attr.name}}(const {{attr.type}}& {{attr.name}}) override { {{attr.name}}_ = {{attr.name}}; } + + {% endfor -%} + + Maybe<AttrVal> GetAttr(const std::string& attr_name) const override; + + private: + {% for attr in op.attrs -%} + {{attr.type}} {{attr.name}}_{% if existsIn(attr, "default") %} = {{attr.default}}{% endif %}; + {% endfor %} +}; +} // namespace schema +{% endfor %} +} // namespace oneflow +)OP_SCHEMA_INC"
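Rendered against the constant op sketched earlier, this header template would expand to something like the following (illustrative only; accessors shown for floating_value, the other attrs follow the same pattern):

class ConstantOp : public OpBase {
 public:
  virtual ~ConstantOp() = default;
  static Maybe<void> GetSbp(user_op::SbpContext* ctx);
  static Maybe<void> InferLogicalTensorDesc(user_op::InferContext* ctx);
  virtual const double& floating_value() const = 0;
  virtual double* mutable_floating_value() = 0;
  virtual void set_floating_value(const double& floating_value) = 0;
  const HashSet<std::string>& AttrNames() const;
};
namespace schema {
class ConstantOp : public oneflow::ConstantOp {
 public:
  const double& floating_value() const override { return floating_value_; }
  double* mutable_floating_value() override { return &floating_value_; }
  void set_floating_value(const double& floating_value) override { floating_value_ = floating_value; }
  Maybe<AttrVal> GetAttr(const std::string& attr_name) const override;
 private:
  double floating_value_ = 0.;  // default carried over from DefaultValuedAttr
};
} // namespace schema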
include != "" %}#include "{{ include }}" +{% else if filename != "" %}#include "{{ to_header(filename) }}" +{% endif %} +#include "oneflow/core/common/auto_registration_factory.h" +#include "oneflow/core/framework/attr_value.h" +#include "oneflow/core/framework/nd_sbp.h" +#include "oneflow/core/framework/infer_nd_sbp_fn_context.h" +#include "oneflow/core/framework/user_op_registry_manager.h" + +namespace oneflow { + +#define REGISTER_OP_SCHEMA(op_type, schema) \ + REGISTER_CLASS_CREATOR(std::string, op_type, OpBase, ([]() { return new schema; })) + +{% for opname, op in ops %} +const HashSet& {{opname}}::AttrNames() const { + static const HashSet attr_names = { {%- for attr in op.attrs -%}"{{attr.name}}", {%- endfor -%} }; + return attr_names; +} + +namespace schema { +Maybe {{opname}}::GetAttr(const std::string& attr_name) const { + {% for attr in op.attrs %}if(attr_name == "{{attr.name}}") { + return CastAttrValue(&{{attr.name}}_); + } + {% endfor -%} + return Error::RuntimeError() << "{{op.name}} op has no attribute named " << attr_name; +} +} // namespace schema + +REGISTER_OP_SCHEMA("user.{{op.name}}", schema::{{opname}}); + +REGISTER_USER_OP("{{op.name}}") +{%- if op.input -%} +{%- for input in op.input -%} +{%- if input.is_optional -%} + .OptionalInput("{{input.name}}") +{%- else -%} + .Input("{{input.name}}") +{%- endif -%} +{%- endfor -%} +{%- endif -%} +{%- if op.output -%} +{%- for output in op.output -%} +{%- if output.is_optional -%} + .OptionalOutput("{{output.name}}") +{%- else -%} + .Output("{{output.name}}") +{%- endif -%} +{%- endfor -%} +{%- endif -%} + +{%- for attr in op.attrs -%} +{%- if existsIn(attr, "default") -%} + .Attr<{{attr.type}}>("{{attr.name}}", {{attr.default}}) +{%- else -%} + .Attr<{{attr.type}}>("{{attr.name}}") +{%- endif -%} +{%- endfor -%} +{%- if op.cpu_only -%} + .SupportCpuOnly() +{%- endif -%} +{%- if op.no_grad -%} + .NoGrad() +{%- endif -%} +{%- if op.same_output_regst_num != -1 -%} + .SetOutputBufferNum({{op.same_output_regst_num}}) +{%- endif -%} +{%- if op.has_nd_sbp_infer_fn -%} + .SetNdSbpInferFn(&{{opname}}::InferNdSbp) +{%- endif -%} +{%- if op.has_get_sbp_fn -%} + .SetGetSbpFn(&{{opname}}::GetSbp) +{%- endif -%} +{%- if op.has_logical_tensor_desc_infer_fn -%} + .SetLogicalTensorDescInferFn(&{{opname}}::InferLogicalTensorDesc) +{%- endif -%} +{%- if op.has_physical_tensor_desc_infer_fn -%} + .SetPhysicalTensorDescInferFn(&{{opname}}::InferPhysicalTensorDesc) +{%- endif -%} +{%- if op.has_data_type_infer_fn -%} + .SetDataTypeInferFn(&{{opname}}::InferDataType) +{%- endif -%} +{%- if op.has_device_infer_fn -%} + .SetDeviceInferFn(&{{opname}}::InferDevice) +{%- endif -%} +{%- if op.has_sbp_signature_infer_fn -%} + .SetSbpSignatureInferFn(&{{opname}}::InferSbpSignature) +{% endif -%} +{%- if op.has_input_arg_modify_fn -%} + .SetInputArgModifyFn(&{{opname}}::ModifyInputArg) +{%- endif -%} +{%- if op.has_output_arg_modify_fn -%} + .SetOutputArgModifyFn(&{{opname}}::ModifyOutputArg) +{%- endif -%} +{%- if op.has_output_blob_time_shape_infer_fn -%} + .SetOutputBlobTimeShapeInferFn(&{{opname}}::InferOutputBlobTimeShape) +{%- endif -%} +{%- if op.has_check_fn -%} + .SetCheckAttrFn(&{{opname}}::CheckAttr) +{%- endif -%} +; +{%- endfor %} +} // namespace oneflow +)OP_SCHEMA_INC" diff --git a/tools/oneflow-tblgen/op_schema_types.inc b/tools/oneflow-tblgen/op_schema_types.inc new file mode 100644 index 00000000000..62656abca62 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_types.inc @@ -0,0 +1,14 @@ +OP_SCHEMA(SI32Attr, int32_t) 
diff --git a/tools/oneflow-tblgen/op_schema_types.inc b/tools/oneflow-tblgen/op_schema_types.inc new file mode 100644 index 00000000000..62656abca62 --- /dev/null +++ b/tools/oneflow-tblgen/op_schema_types.inc @@ -0,0 +1,14 @@ +OP_SCHEMA(SI32Attr, int32_t) +OP_SCHEMA(SI64Attr, int64_t) +OP_SCHEMA(BoolAttr, bool) +OP_SCHEMA(F32Attr, float) +OP_SCHEMA(F64Attr, double) +OP_SCHEMA(StrAttr, std::string) +OP_SCHEMA(ShapeAttr, Shape) +OP_SCHEMA(OneFlow_DataType, DataType) +OP_SCHEMA(SI32ArrayAttr, std::vector<int32_t>) +OP_SCHEMA(SI64ArrayAttr, std::vector<int64_t>) +OP_SCHEMA(F32ArrayAttr, std::vector<float>) +OP_SCHEMA(DTArrayAttr, std::vector<DataType>) +OP_SCHEMA(ShapeArrayAttr, std::vector<Shape>) +OP_SCHEMA(StrArrayAttr, std::vector<std::string>) diff --git a/tools/oneflow-tblgen/tablegen.cpp b/tools/oneflow-tblgen/tablegen.cpp new file mode 100644 index 00000000000..a2e3f4c038f --- /dev/null +++ b/tools/oneflow-tblgen/tablegen.cpp @@ -0,0 +1,104 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/TableGen/Main.h" +#include "llvm/TableGen/Record.h" +#include "llvm/TableGen/SetTheory.h" + +#include "backends.h" + +using namespace llvm; +using namespace oneflow::tblgen; + +enum ActionType { + PrintRecords, + PrintDetailedRecords, + NullBackend, + DumpJSON, + PrintEnums, + PrintSets, + GenOpSchemaHeader, + GenOpSchemaSource, +}; + +namespace llvm { +cl::opt<bool> EmitLongStrLiterals( "long-string-literals", cl::desc("when emitting large string tables, prefer string literals over " "comma-separated char literals.
This can be a readability and " "compile-time performance win, but upsets some compilers"), cl::Hidden, cl::init(true)); +} // end namespace llvm + +namespace { +cl::opt<ActionType> Action( + cl::desc("Action to perform:"), + cl::values(clEnumValN(PrintRecords, "print-records", "Print all records to stdout (default)"), + clEnumValN(PrintDetailedRecords, "print-detailed-records", + "Print full details of all records to stdout"), + clEnumValN(NullBackend, "null-backend", + "Do nothing after parsing (useful for timing)"), + clEnumValN(DumpJSON, "dump-json", "Dump all records as machine-readable JSON"), + clEnumValN(PrintEnums, "print-enums", "Print enum values for a class"), + clEnumValN(PrintSets, "print-sets", "Print expanded sets for testing DAG exprs"), + clEnumValN(GenOpSchemaHeader, "gen-op-schema-h", + "Generate oneflow op schema header code (.h)"), + clEnumValN(GenOpSchemaSource, "gen-op-schema-cpp", + "Generate oneflow op schema source code (.cpp)"))); + +cl::OptionCategory PrintEnumsCat("Options for -print-enums"); +cl::opt<std::string> Class("class", cl::desc("Print Enum list for this class"), + cl::value_desc("class name"), cl::cat(PrintEnumsCat)); + +bool LLVMTableGenMain(raw_ostream& OS, RecordKeeper& Records) { + switch (Action) { + case PrintRecords: OS << Records; break; + case PrintDetailedRecords: EmitDetailedRecords(Records, OS); break; + case NullBackend: break; + case DumpJSON: EmitJSON(Records, OS); break; + case PrintEnums: { + for (Record* Rec : Records.getAllDerivedDefinitions(Class)) OS << Rec->getName() << ", "; + OS << "\n"; + break; + } + case PrintSets: { + SetTheory Sets; + Sets.addFieldExpander("Set", "Elements"); + for (Record* Rec : Records.getAllDerivedDefinitions("Set")) { + OS << Rec->getName() << " = ["; + const std::vector<Record*>* Elts = Sets.expand(Rec); + assert(Elts && "Couldn't expand Set instance"); + for (Record* Elt : *Elts) OS << ' ' << Elt->getName(); + OS << " ]\n"; + } + break; + } + case GenOpSchemaHeader: EmitOpSchemaHeader(Records, OS); break; + case GenOpSchemaSource: EmitOpSchemaSource(Records, OS); break; + } + + return false; +} +} // namespace + +int main(int argc, char** argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions(argc, argv); + + return TableGenMain(argv[0], &LLVMTableGenMain); +}
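A closing note on how the pieces connect: the REGISTER_OP_SCHEMA macro used by the generated source bottoms out in OneFlow's generic class-creator registry (auto_registration_factory.h), so a schema object can be instantiated from an op type name at runtime. Expanded by hand for the constant example, the registration is equivalent to:

REGISTER_CLASS_CREATOR(std::string, "user.constant", OpBase,
                       ([]() { return new schema::ConstantOp; }));

The two backends are then selected on the command line: oneflow_tblgen -gen-op-schema-h emits the header and -gen-op-schema-cpp the source, with -o, --op-include, and --op-dump-json threading through the emitter shown above.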