Skip to content

Commit

Permalink
Add refract function, fix reflect functions and small optimization to…
Browse files Browse the repository at this point in the history
… lerp
  • Loading branch information
redorav committed Mar 27, 2019
1 parent 4d88f58 commit ec88eb1
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 29 deletions.
92 changes: 77 additions & 15 deletions src/hlsl++.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,11 +301,10 @@ namespace hlslpp
hlslpp_inline n128 _hlslpp_lrp_ps(n128 x, n128 y, n128 a)
{
// Slower
// n128 y_minus_x = Masksub_ps(y, x);
// n128 y_minus_x = _hlslpp_sub_ps(y, x);
// n128 result = _hlslpp_madd_ps(y_minus_x, a, x);

n128 one_minus_a = _hlslpp_sub_ps(f4_1, a);
n128 x_one_minus_a = _hlslpp_mul_ps(x, one_minus_a);
n128 x_one_minus_a = _hlslpp_msub_ps(x, x, a); // x * (1 - a)
n128 result = _hlslpp_madd_ps(y, a, x_one_minus_a);
return result;
}
Expand Down Expand Up @@ -619,15 +618,73 @@ namespace hlslpp
return result;
}

// Auxiliary dot3 that adds, subtracts, adds instead of adding all
hlslpp_inline n128 _hlslpp_dot3_asa_ps(n128 x, n128 y)
// https://docs.microsoft.com/en-us/windows/desktop/direct3dhlsl/dx-graphics-hlsl-reflect
// v = i - 2 * n * dot(i, n)

hlslpp_inline n128 _hlslpp_reflect1_ps(n128 i, n128 n)
{
n128 multi = _hlslpp_mul_ps(x, y); // Multiply components together
n128 shuf1 = _hlslpp_perm_yyyy_ps(multi); // Move y into x
n128 add1 = _hlslpp_sub_ps(multi, shuf1); // Contains x-y, _, _, _
n128 shuf2 = _hlslpp_perm_zzzz_ps(multi); // Move z into x
n128 result = _hlslpp_add_ss(add1, shuf2); // Contains x-y+z, _, _, _
return result;
return _hlslpp_sub_ps(i, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n, _hlslpp_mul_ps(i, n))));
}

hlslpp_inline n128 _hlslpp_reflect2_ps(n128 i, n128 n)
{
return _hlslpp_sub_ps(i, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n, _hlslpp_perm_xxxx_ps(_hlslpp_dot2_ps(i, n)))));
}

hlslpp_inline n128 _hlslpp_reflect3_ps(n128 i, n128 n)
{
return _hlslpp_sub_ps(i, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n, _hlslpp_perm_xxxx_ps(_hlslpp_dot3_ps(i, n)))));
}

hlslpp_inline n128 _hlslpp_reflect4_ps(n128 i, n128 n)
{
return _hlslpp_sub_ps(i, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n, _hlslpp_perm_xxxx_ps(_hlslpp_dot4_ps(i, n)))));
}

// https://www.khronos.org/registry/OpenGL-Refpages/gl4/html/refract.xhtml
//
// k = 1.0 - ior * ior * (1.0 - dot(n, i) * dot(n, i));
// if (k < 0.0)
// return 0.0;
// else
// return ior * i - (ior * dot(n, i) + sqrt(k)) * n;

hlslpp_inline n128 _hlslpp_refract_ps(n128 i, n128 n, n128 ior, n128 NdotI)
{
NdotI = _hlslpp_perm_xxxx_ps(NdotI); // Propagate to all components (dot lives in x)

n128 ior2 = _hlslpp_mul_ps(ior, ior); // ior^2
n128 invNdotI2 = _hlslpp_subm_ps(f4_1, NdotI, NdotI); // 1.0 - dot(n, i)^2
n128 k = _hlslpp_subm_ps(f4_1, ior2, invNdotI2); // k = 1.0 - ior^2 * (1.0 - dot(n, i)^2)

n128 sqrtK = _hlslpp_sqrt_ps(k); // sqrt(k)
n128 iorNdotISqrtk = _hlslpp_madd_ps(ior, NdotI, sqrtK); // ior * dot(n, i) + sqrt(k)
n128 iorNdotISqrtkn = _hlslpp_mul_ps(iorNdotISqrtk, n); // (ior * dot(n, i) + sqrt(k)) * n
n128 result = _hlslpp_msub_ps(ior, i, iorNdotISqrtkn); // ior * i - (ior * dot(n, i) + sqrt(k)) * n

n128 klt0 = _hlslpp_cmplt_ps(k, _hlslpp_setzero_ps()); // Whether k was less than 0

return _hlslpp_sel_ps(result, _hlslpp_setzero_ps(), klt0); // Select between 0 and the result
}

hlslpp_inline n128 _hlslpp_refract1_ps(n128 i, n128 n, n128 ior)
{
return _hlslpp_refract_ps(i, n, ior, _hlslpp_mul_ps(i, n));
}

hlslpp_inline n128 _hlslpp_refract2_ps(n128 i, n128 n, n128 ior)
{
return _hlslpp_refract_ps(i, n, _hlslpp_perm_xxxx_ps(ior), _hlslpp_dot2_ps(i, n));
}

hlslpp_inline n128 _hlslpp_refract3_ps(n128 i, n128 n, n128 ior)
{
return _hlslpp_refract_ps(i, n, _hlslpp_perm_xxxx_ps(ior), _hlslpp_dot3_ps(i, n));
}

hlslpp_inline n128 _hlslpp_refract4_ps(n128 i, n128 n, n128 ior)
{
return _hlslpp_refract_ps(i, n, _hlslpp_perm_xxxx_ps(ior), _hlslpp_dot4_ps(i, n));
}

// Returns true if x is not +infinity or -infinity
Expand Down Expand Up @@ -1316,10 +1373,15 @@ namespace hlslpp
float3 rcp(const float3& f) { return float3(_hlslpp_rcp_ps(f.vec)); }
float4 rcp(const float4& f) { return float4(_hlslpp_rcp_ps(f.vec)); }

float1 reflect(const float1& i, const float1& n) { return float1(_hlslpp_sub_ps(i.vec, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n.vec, _hlslpp_perm_xxxx_ps(_hlslpp_mul_ps(i.vec, n.vec)))))); }
float2 reflect(const float2& i, const float2& n) { return float2(_hlslpp_sub_ps(i.vec, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n.vec, _hlslpp_perm_xxxx_ps(_hlslpp_mul_ps(i.vec, n.vec)))))); }
float3 reflect(const float3& i, const float3& n) { return float3(_hlslpp_sub_ps(i.vec, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n.vec, _hlslpp_perm_xxxx_ps(_hlslpp_mul_ps(i.vec, n.vec)))))); }
float4 reflect(const float4& i, const float4& n) { return float4(_hlslpp_sub_ps(i.vec, _hlslpp_mul_ps(f4_2, _hlslpp_mul_ps(n.vec, _hlslpp_perm_xxxx_ps(_hlslpp_mul_ps(i.vec, n.vec)))))); }
float1 reflect(const float1& i, const float1& n) { return float1(_hlslpp_reflect1_ps(i.vec, n.vec)); }
float2 reflect(const float2& i, const float2& n) { return float2(_hlslpp_reflect2_ps(i.vec, n.vec)); }
float3 reflect(const float3& i, const float3& n) { return float3(_hlslpp_reflect3_ps(i.vec, n.vec)); }
float4 reflect(const float4& i, const float4& n) { return float4(_hlslpp_reflect4_ps(i.vec, n.vec)); }

float1 refract(const float1& i, const float1& n, const float1& ior) { return float1(_hlslpp_refract1_ps(i.vec, n.vec, ior.vec)); }
float2 refract(const float2& i, const float2& n, const float1& ior) { return float2(_hlslpp_refract2_ps(i.vec, n.vec, ior.vec)); }
float3 refract(const float3& i, const float3& n, const float1& ior) { return float3(_hlslpp_refract3_ps(i.vec, n.vec, ior.vec)); }
float4 refract(const float4& i, const float4& n, const float1& ior) { return float4(_hlslpp_refract4_ps(i.vec, n.vec, ior.vec)); }

float1 rsqrt(const float1& f) { return float1(_hlslpp_rsqrt_ps(f.vec)); }
float2 rsqrt(const float2& f) { return float2(_hlslpp_rsqrt_ps(f.vec)); }
Expand Down
4 changes: 2 additions & 2 deletions src/hlsl++_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ typedef __m256i n256i;

#else

n128i _hlslpp_sllv_epi32(n128i x, n128i count)
inline n128i _hlslpp_sllv_epi32(n128i x, n128i count)
{
n128i count1 = _hlslpp_perm_epi32(count, HLSLPP_SHUFFLE_MASK(1, 0, 0, 0));
n128i count2 = _hlslpp_perm_epi32(count, HLSLPP_SHUFFLE_MASK(2, 0, 0, 0));
Expand Down Expand Up @@ -283,7 +283,7 @@ n128i _hlslpp_sllv_epi32(n128i x, n128i count)

#else

n128i _hlslpp_srlv_epi32(n128i x, n128i count)
inline n128i _hlslpp_srlv_epi32(n128i x, n128i count)
{
n128i count1 = _hlslpp_perm_epi32(count, HLSLPP_SHUFFLE_MASK(1, 0, 0, 0));
n128i count2 = _hlslpp_perm_epi32(count, HLSLPP_SHUFFLE_MASK(2, 0, 0, 0));
Expand Down
34 changes: 22 additions & 12 deletions src/hlsl++_unit_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -672,20 +672,20 @@ void RunUnitTests()
float3 vneq3 = vfoo3 != vbar3; eq(vneq3, (float)vfoo3.x != (float)vbar3.x, (float)vfoo3.y != (float)vbar3.y, (float)vfoo3.z != (float)vbar3.z);
float4 vneq4 = vfoo4 != vbar4; eq(vneq4, (float)vfoo4.x != (float)vbar4.x, (float)vfoo4.y != (float)vbar4.y, (float)vfoo4.z != (float)vbar4.z, (float)vfoo4.w != (float)vbar4.w);

vfoo1 = -vbar1.r; eq(vfoo1, -vbar1.r);
vfoo2 = -vbar2.gr; eq(vfoo2, (float)-vbar2.g, (float)-vbar2.r);
vfoo3 = -vbar3.bgg; eq(vfoo3, (float)-vbar3.b, (float)-vbar3.g, (float)-vbar3.g);
vfoo4 = -vbar4.rbgr; eq(vfoo4, (float)-vbar4.r, (float)-vbar4.b, (float)-vbar4.g, (float)-vbar4.r);
float1 vneg_1 = -vbar1.r; eq(vneg_1, -vbar1.r);
float2 vneg_2 = -vbar2.gr; eq(vneg_2, (float)-vbar2.g, (float)-vbar2.r);
float3 vneg_3 = -vbar3.bgg; eq(vneg_3, (float)-vbar3.b, (float)-vbar3.g, (float)-vbar3.g);
float4 vneg_4 = -vbar4.rbgr; eq(vneg_4, (float)-vbar4.r, (float)-vbar4.b, (float)-vbar4.g, (float)-vbar4.r);

float1 vabs1 = abs(vfoo1); eq(vabs1, abs((float)vfoo1.x));
float2 vabs2 = abs(vfoo2); eq(vabs2, abs((float)vfoo2.x), abs((float)vfoo2.y));
float3 vabs3 = abs(vfoo3); eq(vabs3, abs((float)vfoo3.x), abs((float)vfoo3.y), abs((float)vfoo3.z));
float4 vabs4 = abs(vfoo4); eq(vabs4, abs((float)vfoo4.x), abs((float)vfoo4.y), abs((float)vfoo4.z), abs((float)vfoo4.w));

vfoo1 = abs(-vfoo1); eq(vabs1, abs((float)-vfoo1));
vfoo2 = abs(-vfoo2); eq(vabs2, abs((float)-vfoo2.x), abs((float)-vfoo2.y));
vfoo3 = abs(-vfoo3); eq(vabs3, abs((float)-vfoo3.x), abs((float)-vfoo3.y), abs((float)-vfoo3.z));
vfoo4 = abs(-vfoo4); eq(vabs4, abs((float)-vfoo4.x), abs((float)-vfoo4.y), abs((float)-vfoo4.z), abs((float)-vfoo4.w));
float1 vabsneg_1 = abs(-vfoo1); eq(vabsneg_1, abs((float)-vfoo1));
float2 vabsneg_2 = abs(-vfoo2); eq(vabsneg_2, abs((float)-vfoo2.x), abs((float)-vfoo2.y));
float3 vabsneg_3 = abs(-vfoo3); eq(vabsneg_3, abs((float)-vfoo3.x), abs((float)-vfoo3.y), abs((float)-vfoo3.z));
float4 vabsneg_4 = abs(-vfoo4); eq(vabsneg_4, abs((float)-vfoo4.x), abs((float)-vfoo4.y), abs((float)-vfoo4.z), abs((float)-vfoo4.w));

float1 vabs_swiz_1 = abs(vfoo1.r); eq(vabs_swiz_1, abs((float)vfoo1.x));
float2 vabs_swiz_2 = abs(vfoo2.yx); eq(vabs_swiz_2, abs((float)vfoo2.g), abs((float)vfoo2.r));
Expand Down Expand Up @@ -1009,10 +1009,15 @@ void RunUnitTests()
float3 vmax_swiz_3 = max(vfoo3.gbr, vbar3.xyy);
float4 vmax_swiz_4 = max(vfoo4.brga, vbar4.yxzw);

float1 vnormalize_1 = normalize(vfoo1);
float2 vnormalize_2 = normalize(vfoo2);
float3 vnormalize_3 = normalize(vfoo3);
float4 vnormalize_4 = normalize(vfoo4);
float1 vnormalize_foo_1 = normalize(vfoo1);
float2 vnormalize_foo_2 = normalize(vfoo2);
float3 vnormalize_foo_3 = normalize(vfoo3);
float4 vnormalize_foo_4 = normalize(vfoo4);

float1 vnormalize_bar_1 = normalize(vbar1);
float2 vnormalize_bar_2 = normalize(vbar2);
float3 vnormalize_bar_3 = normalize(vbar3);
float4 vnormalize_bar_4 = normalize(vbar4);

float1 vnormalize_swiz_1 = normalize(vfoo1.r);
float2 vnormalize_swiz_2 = normalize(vfoo2.rg);
Expand Down Expand Up @@ -1069,6 +1074,11 @@ void RunUnitTests()
float3 vreflect_swiz_3_b = reflect(vfoo3.bgr, vbar3.ggr);
float4 vreflect_swiz_4_b = reflect(vfoo4.xxzy, vbar4.wxyy);

float1 vrefract1 = refract(vnormalize_foo_1, vnormalize_bar_1, float1(0.1f));
float2 vrefract2 = refract(vnormalize_foo_2, vnormalize_bar_2, float1(-0.7f));
float3 vrefract3 = refract(vnormalize_foo_3, vnormalize_bar_3, float1(0.1f));
float4 vrefract4 = refract(vnormalize_foo_4, vnormalize_bar_4, float1(0.1f));

float1 vrsqrt1 = rsqrt(vfoo1);
float2 vrsqrt2 = rsqrt(vfoo2);
float3 vrsqrt3 = rsqrt(vfoo3);
Expand Down

0 comments on commit ec88eb1

Please sign in to comment.