File tree Expand file tree Collapse file tree 2 files changed +18
-11
lines changed 
dpctl/tensor/libtensor/include/kernels/elementwise_functions Expand file tree Collapse file tree 2 files changed +18
-11
lines changed Original file line number Diff line number Diff line change @@ -114,21 +114,22 @@ template <typename argT, typename resT> struct Expm1Functor
114114            }
115115
116116            //  x, y finite numbers
117-             realT cosY_val;
118-             auto  cosY_val_multi_ptr = sycl::address_space_cast<
119-                 sycl::access::address_space::private_space,
120-                 sycl::access::decorated::yes>(&cosY_val);
121-             const  realT sinY_val = sycl::sincos (y, cosY_val_multi_ptr);
122-             const  realT sinhalfY_val = std::sin (y / 2 );
117+             const  realT cosY_val = std::cos (y);
118+             const  realT sinY_val = (y == 0 ) ? y : std::sin (y);
119+             const  realT sinhalfY_val = (y == 0 ) ? y : std::sin (y / 2 );
123120
124121            const  realT res_re =
125122                std::expm1 (x) * cosY_val - 2  * sinhalfY_val * sinhalfY_val;
126-             const   realT res_im = std::exp (x) * sinY_val;
123+             realT res_im = std::exp (x) * sinY_val;
127124            return  resT{res_re, res_im};
128125        }
129126        else  {
130127            static_assert (std::is_floating_point_v<argT> ||
131128                          std::is_same_v<argT, sycl::half>);
129+             static_assert (std::is_same_v<argT, resT>);
130+             if  (in == 0 ) {
131+                 return  in;
132+             }
132133            return  std::expm1 (in);
133134        }
134135    }
Original file line number Diff line number Diff line change @@ -81,11 +81,15 @@ template <typename argT, typename resT> struct SinFunctor
8181             */  
8282            if  (in_re_finite && in_im_finite) {
8383#ifdef  USE_SYCL_FOR_COMPLEX_TYPES
84-                 return  exprm_ns::sin (
84+                 resT res =  exprm_ns::sin (
8585                    exprm_ns::complex <realT>(in)); //  std::sin(in);
8686#else 
87-                 return  std::sin (in);
87+                 resT res =  std::sin (in);
8888#endif 
89+                 if  (in_re == realT (0 )) {
90+                     res.real (std::copysign (realT (0 ), in_re));
91+                 }
92+                 return  res;
8993            }
9094
9195            /* 
@@ -176,8 +180,10 @@ template <typename argT, typename resT> struct SinFunctor
176180            return  resT{sinh_im, -sinh_re};
177181        }
178182        else  {
179-             static_assert (std::is_floating_point_v<argT> ||
180-                           std::is_same_v<argT, sycl::half>);
183+             static_assert (std::is_same_v<argT, resT>);
184+             if  (in == 0 ) {
185+                 return  in;
186+             }
181187            return  std::sin (in);
182188        }
183189    }
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments