@@ -544,8 +544,6 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
544544
545545#ifdef USE_ROCM
546546namespace {
547- // Static functor type checker for binary functors with
548- // float as the type of both parameters.
549547template <
550548 typename TupleLike,
551549 typename FirstParamTy,
@@ -554,12 +552,11 @@ template <
554552 size_t arg_num = 0 >
555553struct check_binary_functor_types_for_specialization {
556554 constexpr static inline bool check () {
557- bool current = false ;
558555 if constexpr (arity != 2 )
559556 return false ;
560557 if constexpr (arg_num == 0 ) {
561558 using SelectedType = std::tuple_element_t <arg_num, TupleLike>;
562- if constexpr (std::is_same_v<float , SelectedType>)
559+ if constexpr (std::is_same_v<FirstParamTy , SelectedType>)
563560 return check_binary_functor_types_for_specialization<
564561 TupleLike,
565562 FirstParamTy,
@@ -568,7 +565,7 @@ struct check_binary_functor_types_for_specialization {
568565 arg_num + 1 >::check ();
569566 } else if constexpr (arg_num == 1 ) {
570567 using SelectedType2 = std::tuple_element_t <arg_num, TupleLike>;
571- if constexpr (std::is_same_v<float , SelectedType2>)
568+ if constexpr (std::is_same_v<SecondParamTy , SelectedType2>)
572569 return check_binary_functor_types_for_specialization<
573570 TupleLike,
574571 FirstParamTy,
@@ -613,30 +610,91 @@ struct check_binary_functor_types_for_specialization<
613610};
614611
615612// The following is a list of type specializations for vectorized_templated
616- // elementwise kernel. It refers to the first and second runtime types of the
617- // arguments of a binary functor.
// (Removed by this change) Previous table: four {first input, second input}
// dtype pairs, sized by a separately declared constant.
618- constexpr int number_of_binary_specializations = 4 ;
619- const std::
620- array<std::array<c10::ScalarType, 2 >, number_of_binary_specializations>
621- rt_binary_specializations = {
622- {{c10::CppTypeToScalarType<float >::value,
623- c10::CppTypeToScalarType<BFloat16>::value},
624- {c10::CppTypeToScalarType<BFloat16>::value,
625- c10::CppTypeToScalarType<float >::value},
626- {c10::CppTypeToScalarType<float >::value,
627- c10::CppTypeToScalarType<Half>::value},
628- {c10::CppTypeToScalarType<Half>::value,
629- c10::CppTypeToScalarType<float >::value}}};
613+ // elementwise kernel. The three types refer to runtime types of the output
614+ // tensor, first tensor argument, and the second tensor argument used for a
615+ // binary functor.
// Replacement table: six {output, first input, second input} dtype triples.
// In every triple the two inputs are float plus exactly one reduced-precision
// type (BFloat16 or Half). For a (float, reduced) input order only a float
// output is listed; for a (reduced, float) input order both a float output and
// a reduced-precision output are listed.
// The table is constexpr with a deduced std::array type: its entry count is
// read through .size() and its individual entries are used as template
// arguments by type_specialized_kernel_launcher.
616+ constexpr std::array rt_binary_specializations = {
617+ std::array<c10::ScalarType, 3 >(
618+ {c10::CppTypeToScalarType<float >::value,
619+ c10::CppTypeToScalarType<float >::value,
620+ c10::CppTypeToScalarType<BFloat16>::value}),
621+ std::array<c10::ScalarType, 3 >(
622+ {c10::CppTypeToScalarType<float >::value,
623+ c10::CppTypeToScalarType<BFloat16>::value,
624+ c10::CppTypeToScalarType<float >::value}),
625+ std::array<c10::ScalarType, 3 >(
626+ {c10::CppTypeToScalarType<BFloat16>::value,
627+ c10::CppTypeToScalarType<BFloat16>::value,
628+ c10::CppTypeToScalarType<float >::value}),
629+ std::array<c10::ScalarType, 3 >(
630+ {c10::CppTypeToScalarType<float >::value,
631+ c10::CppTypeToScalarType<float >::value,
632+ c10::CppTypeToScalarType<Half>::value}),
633+ std::array<c10::ScalarType, 3 >(
634+ {c10::CppTypeToScalarType<float >::value,
635+ c10::CppTypeToScalarType<Half>::value,
636+ c10::CppTypeToScalarType<float >::value}),
637+ std::array<c10::ScalarType, 3 >(
638+ {c10::CppTypeToScalarType<Half>::value,
639+ c10::CppTypeToScalarType<Half>::value,
640+ c10::CppTypeToScalarType<float >::value})};
630641
// Runtime gate for the type-specialized vectorized elementwise path.
// Returns false when `iter` does not have exactly two inputs. Otherwise returns
// true iff the triple (iter.dtype(0), iter.input_dtype(0), iter.input_dtype(1))
// equals one row of rt_binary_specializations, and false when no row matches.
631642bool check_binary_rt_types_for_specialization (TensorIteratorBase& iter) {
632643 if (iter.ninputs () != 2 )
633644 return false ;
// (Removed by this change) Old loop: hard-coded to four entries and compared
// only the two input dtypes.
634- for (int i = 0 ; i < 4 ; i++ )
635- if (iter.input_dtype (0 ) == rt_binary_specializations[i][ 0 ] &&
636- iter.input_dtype (1 ) == rt_binary_specializations[i][ 1 ])
// New loop: visits every row of the table and additionally compares
// iter.dtype(0) with the row's first entry, which the table's comment
// describes as the output tensor's runtime type.
// NOTE(review): `spec` is declared by value, so each three-entry row is copied
// per iteration.
645+ for (auto spec : rt_binary_specializations )
646+ if (iter.dtype (0 ) == spec[ 0 ] && iter. input_dtype ( 0 ) == spec[ 1 ] &&
647+ iter.input_dtype (1 ) == spec[ 2 ])
637648 return true ;
638649 return false ;
639650}
651+
// Launcher bound at compile time to row `arg_index` of
// rt_binary_specializations. gpu_kernel_impl passes this template to
// memory::detail::static_unroll together with the table's size; presumably
// static_unroll (defined outside this view) then invokes apply() once per row
// index -- TODO confirm against its definition.
652+ template <int arg_index>
653+ struct type_specialized_kernel_launcher {
654+ template <
655+ typename func_t ,
656+ typename array_t ,
657+ typename inp_calc_t ,
658+ typename out_calc_t ,
659+ typename loader_t ,
660+ typename storer_t >
// apply(): ret_t / arg0_t / arg1_t are runtime ScalarType values (not C++
// types, despite the _t suffix) for the output and the two inputs. All other
// parameters are taken by value and forwarded unchanged to the kernel launch.
661+ static void apply (
662+ ScalarType ret_t ,
663+ ScalarType arg0_t ,
664+ ScalarType arg1_t ,
665+ int64_t numel,
666+ func_t f,
667+ array_t data,
668+ inp_calc_t input_offset_calculator,
669+ out_calc_t output_offset_calculator,
670+ loader_t loader,
671+ storer_t storer) {
// Launch only when all three runtime dtypes equal this row's entries; for any
// other row the call is a no-op (there is no else branch).
672+ if (ret_t == rt_binary_specializations[arg_index][0 ] &&
673+ arg0_t == rt_binary_specializations[arg_index][1 ] &&
674+ arg1_t == rt_binary_specializations[arg_index][2 ])
675+ launch_vectorized_templated_kernel<
676+ func_t ,
677+ array_t ,
678+ inp_calc_t ,
679+ out_calc_t ,
680+ loader_t ,
681+ storer_t ,
// The row's constexpr ScalarType entries are turned into C++ types through
// decltype of ScalarTypeToCPPType<...>::t and supplied as the three trailing
// template arguments (output type, first input type, second input type).
682+ decltype (c10::impl::ScalarTypeToCPPType<
683+ rt_binary_specializations[arg_index][0 ]>::t),
684+ decltype (c10::impl::ScalarTypeToCPPType<
685+ rt_binary_specializations[arg_index][1 ]>::t),
686+ decltype (c10::impl::ScalarTypeToCPPType<
687+ rt_binary_specializations[arg_index][2 ]>::t)>(
688+ numel,
689+ f,
690+ data,
691+ input_offset_calculator,
692+ output_offset_calculator,
693+ loader,
694+ storer);
695+ }
696+ };
697+
640698} // namespace
641699#endif
642700
@@ -666,10 +724,10 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
666724#ifdef USE_ROCM
667725 // Attempt to call specialized vectorized elementwise kernel
668726 // that enables interleaving.
669- if (false && check_binary_rt_types_for_specialization (iter) &&
727+ if (check_binary_rt_types_for_specialization (iter) &&
670728 memory::can_vectorize_up_to<func_t >(data) > 1 ) {
671- // constexpr to reduce the amount of kernels (empty) generated for
672- // unrolled templated elementwise and limit which functors are actually
729+ // constexpr to reduce the amount of kernels generated for
730+ // vectorized templated elementwise and limit which functors are actually
673731 // applied to the load and store at compile time.
674732 using func_tuple = typename traits::ArgsTuple;
675733 if constexpr (
@@ -679,7 +737,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
679737 float ,
680738 float ,
681739 traits::arity,
682- /* current =*/ 0 >::check ()) {
740+ /* arg_num =*/ 0 >::check ()) {
683741 // If we got here, we know we are in one of the specialized cases. We
684742 // need to translate the runtime type to a statically known type. This
685743 // is effectively hoisting to the host the switch over runtime type in
@@ -689,90 +747,24 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
689747 auto output_offset_calculator = TrivialOffsetCalculator<1 >();
690748 auto loader = memory::LoadWithCast<traits::arity>(iter);
691749 auto storer = memory::StoreWithCast<1 >(iter);
692- if (iter.input_dtype (0 ) == c10::CppTypeToScalarType<float >::value &&
693- iter.input_dtype (1 ) == c10::CppTypeToScalarType<BFloat16>::value)
694- launch_vectorized_templated_kernel<
695- func_t ,
696- at::detail::Array<char *, ntensors>,
697- decltype (input_offset_calculator),
698- decltype (output_offset_calculator),
699- decltype (loader),
700- decltype (storer),
701- float ,
702- float ,
703- BFloat16>(
704- numel,
705- f,
706- data,
707- input_offset_calculator,
708- output_offset_calculator,
709- loader,
710- storer);
711- else if (
712- iter.input_dtype (0 ) == c10::CppTypeToScalarType<BFloat16>::value &&
713- iter.input_dtype (1 ) == c10::CppTypeToScalarType<float >::value)
714- launch_vectorized_templated_kernel<
715- func_t ,
716- at::detail::Array<char *, ntensors>,
717- decltype (input_offset_calculator),
718- decltype (output_offset_calculator),
719- decltype (loader),
720- decltype (storer),
721- float ,
722- BFloat16,
723- float >(
724- numel,
725- f,
726- data,
727- input_offset_calculator,
728- output_offset_calculator,
729- loader,
730- storer);
731- else if (
732- iter.input_dtype (0 ) == c10::CppTypeToScalarType<float >::value &&
733- iter.input_dtype (1 ) == c10::CppTypeToScalarType<Half>::value)
734- launch_vectorized_templated_kernel<
735- func_t ,
736- at::detail::Array<char *, ntensors>,
737- decltype (input_offset_calculator),
738- decltype (output_offset_calculator),
739- decltype (loader),
740- decltype (storer),
741- float ,
742- float ,
743- Half>(
744- numel,
745- f,
746- data,
747- input_offset_calculator,
748- output_offset_calculator,
749- loader,
750- storer);
751- else if (
752- iter.input_dtype (0 ) == c10::CppTypeToScalarType<Half>::value &&
753- iter.input_dtype (1 ) == c10::CppTypeToScalarType<float >::value)
754- launch_vectorized_templated_kernel<
755- func_t ,
756- at::detail::Array<char *, ntensors>,
757- decltype (input_offset_calculator),
758- decltype (output_offset_calculator),
759- decltype (loader),
760- decltype (storer),
761- float ,
762- Half,
763- float >(
764- numel,
765- f,
766- data,
767- input_offset_calculator,
768- output_offset_calculator,
769- loader,
770- storer);
771- else
772- TORCH_CHECK (false , " unreachable" );
750+ memory::detail::static_unroll<
751+ type_specialized_kernel_launcher,
752+ rt_binary_specializations.size ()>::
753+ with_args (
754+ iter.dtype (0 ),
755+ iter.input_dtype (0 ),
756+ iter.input_dtype (1 ),
757+ numel,
758+ f,
759+ data,
760+ input_offset_calculator,
761+ output_offset_calculator,
762+ loader,
763+ storer);
773764 return ;
774765 }
775766 }
767+
776768 at::detail::Array<ScalarType, ntensors> dtypes;
777769 auto inner_strides = iter.get_inner_strides ();
778770 at::detail::Array<int , ntensors> strides;
0 commit comments